Package paramz :: Module ties_and_remappings
[hide private]
[frames] | [no frames]

Source Code for Module paramz.ties_and_remappings

  1  #=============================================================================== 
  2  # Copyright (c) 2012 - 2014, GPy authors (see AUTHORS.txt). 
  3  # Copyright (c) 2014, James Hensman, Max Zwiessele 
  4  # Copyright (c) 2015, Max Zwiessele 
  5  # 
  6  # All rights reserved. 
  7  #  
  8  # Redistribution and use in source and binary forms, with or without 
  9  # modification, are permitted provided that the following conditions are met: 
 10  #  
 11  # * Redistributions of source code must retain the above copyright notice, this 
 12  #   list of conditions and the following disclaimer. 
 13  #  
 14  # * Redistributions in binary form must reproduce the above copyright notice, 
 15  #   this list of conditions and the following disclaimer in the documentation 
 16  #   and/or other materials provided with the distribution. 
 17  #  
 18  # * Neither the name of paramax nor the names of its 
 19  #   contributors may be used to endorse or promote products derived from 
 20  #   this software without specific prior written permission. 
 21  #  
 22  # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
 23  # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
 24  # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 
 25  # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 
 26  # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 
 27  # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
 28  # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 
 29  # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 
 30  # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
 31  # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 32  #=============================================================================== 
 33   
 34  import numpy as np 
 35  from .parameterized import Parameterized 
 36  from .param import Param 
 37   
class Remapping(Parameterized):
    """
    Base class for re-mapped parameters.

    Sub-classes implement :meth:`mapping` (and :meth:`callback`). Whenever the
    parameters change, the values returned by :meth:`mapping` are written into
    the highest parent's ``param_array`` at the indices stored for this object
    in the highest parent's ``constraints``.
    """

    def mapping(self):
        """
        The return value of this function gives the values which the re-mapped
        parameters should take. Implement in sub-classes.
        """
        raise NotImplementedError

    def callback(self):
        """Hook to be implemented in sub-classes; the base class raises NotImplementedError."""
        raise NotImplementedError

    def __str__(self):
        return self.name

    def parameters_changed(self):
        """Write the output of :meth:`mapping` into the re-mapped parameters and notify them."""
        # ensure all out parameters have the correct value, as specified by our mapping
        index = self._highest_parent_.constraints[self]
        self._highest_parent_.param_array[index] = self.mapping()
        # NOTE(review): `tied_parameters` is not defined in this module's visible
        # code; presumably supplied by a sub-class or the parent hierarchy - confirm.
        # A plain loop: notification is a side effect, no list needs to be built.
        for p in self.tied_parameters:
            p.notify_observers(which=self)
57
class Fix(Remapping):
    """Subclass of :class:`Remapping` that currently adds no behaviour of its own."""
60 61 62 63
class Tie(Parameterized):
    """
    The new parameter tie framework. (under development)

    All the parameters tied together get a new parameter inside the *Tie* object.
    Its value should always be equal to all the tied parameters, and its gradient
    is the sum of the gradients of all the parameters in its tie group.

    Implementation Details:

    The *Tie* object should only exist on the top of param tree (the highest parent).

    self.label_buf:
        It uses a label buffer that has the same length as all the parameters
        (self._highest_parent_.param_array). The buffer keeps track of all the
        tied parameters. All the tied parameters have a label (an integer) higher
        than 0, and the parameters that have the same label are tied together.

    self.buf_idx:
        An auxiliary index list for the global index of the tie parameter inside
        the *Tie* object.

    self._tie_:
        A boolean array over the global parameter array: False marks a tied
        parameter; True marks an untied parameter or a tie parameter itself.

    .. warning::

        This is not implemented yet and will be different, with high degree of certainty.
        Do not use!

    """

    def __init__(self, name='tie'):
        super(Tie, self).__init__(name)
        # The 'tied' Param holding one value per tie group (created lazily)
        self.tied_param = None
        # The buffer keeps track of tie status
        self.label_buf = None
        # The global indices of the 'tied' param
        self.buf_idx = None
        # A boolean array indicating non-tied parameters
        self._tie_ = None

    def _all_true_flags(self):
        """Return a fresh all-True flag array sized like the global param_array."""
        # builtin bool: the np.bool alias was removed in NumPy 1.24
        return np.ones((self._highest_parent_.param_array.size,), dtype=bool)

    def getTieFlag(self, p=None):
        """
        Return the tie-flag array (True where a parameter is not tied away).

        While no tie parameter exists yet, the flag array is (re)allocated as
        all-True whenever it is missing or its size no longer matches the
        global param_array.

        :param p: optional parameter; if given, only the flags for *p*'s
            entries are returned.
        """
        if self.tied_param is None:
            if self._tie_ is None or self._tie_.size != self._highest_parent_.param_array.size:
                self._tie_ = self._all_true_flags()
        if p is not None:
            return self._tie_[p._highest_parent_._raveled_index_for(p)]
        return self._tie_

    def _init_labelBuf(self):
        """Lazily allocate the label buffer and make sure the flag array matches param_array."""
        if self.label_buf is None:
            # builtin int: the np.int alias was removed in NumPy 1.24
            self.label_buf = np.zeros(self._highest_parent_.param_array.shape, dtype=int)
        if self._tie_ is None or self._tie_.size != self._highest_parent_.param_array.size:
            self._tie_ = self._all_true_flags()

    def _updateTieFlag(self):
        """Recompute the tie flags from the label buffer."""
        if self._tie_.size != self.label_buf.size:
            self._tie_ = self._all_true_flags()
        # Every parameter carrying a label is tied away ...
        self._tie_[self.label_buf > 0] = False
        # ... except the tie parameters themselves, which stay free.
        self._tie_[self.buf_idx] = True

    def add_tied_parameter(self, p, p2=None):
        """
        Tie the list of parameters p together (p2==None) or
        Tie the list of parameters p with the list of parameters p2 (p2!=None)

        :param p: the parameter(s) to tie together.
        :param p2: tying with a second list is not implemented yet; passing
            it raises NotImplementedError instead of silently doing nothing.
        :raises NotImplementedError: if *p2* is given.
        """
        if p2 is not None:
            raise NotImplementedError(
                "Tying p with a second list of parameters p2 is not implemented yet")
        self._init_labelBuf()
        idx = self._highest_parent_._raveled_index_for(p)
        # Set all entries of p to their mean; that value seeds the tie parameter.
        val = self._sync_val_group(idx)
        if np.all(self.label_buf[idx] == 0):
            # None of p has been tied before: allocate one new tie slot.
            tie_idx = self._expandTieParam(1)
            tie_id = self.label_buf.max() + 1
            self.label_buf[tie_idx] = tie_id
        else:
            # Some of p are already tied: merge every group they belong to.
            labels = self.label_buf[idx]
            ids = np.unique(labels[labels > 0])
            tie_id, tie_idx = self._merge_tie_param(ids)
        self._highest_parent_.param_array[tie_idx] = val
        # Recompute p's indices: resizing the 'tied' Param may have moved them.
        idx = self._highest_parent_._raveled_index_for(p)
        self.label_buf[idx] = tie_id
        self._updateTieFlag()

    def _merge_tie_param(self, ids):
        """
        Merge the tie parameters with ids in the list.

        The first label in *ids* survives. The tie slots of the other labels are
        removed, and every parameter carrying one of them is relabelled.

        :returns: (surviving label, global index of its tie slot)
        """
        if len(ids) == 1:
            id_final_idx = self.buf_idx[self.label_buf[self.buf_idx] == ids[0]][0]
            return ids[0], id_final_idx
        id_final = ids[0]
        ids_rm = ids[1:]
        label_buf_param = self.label_buf[self.buf_idx]
        # Positions (within tied_param) of the tie slots to drop.
        idx_param = [np.where(label_buf_param == i)[0][0] for i in ids_rm]
        self._removeTieParam(idx_param)
        for i in ids_rm:
            self.label_buf[self.label_buf == i] = id_final
        id_final_idx = self.buf_idx[self.label_buf[self.buf_idx] == id_final][0]
        return id_final, id_final_idx

    def _sync_val_group(self, idx):
        """Set the entries at global indices *idx* to their mean and return that value."""
        self._highest_parent_.param_array[idx] = self._highest_parent_.param_array[idx].mean()
        return self._highest_parent_.param_array[idx][0]

    def _expandTieParam(self, num):
        """
        Expand the tie param with the number of *num* parameters.

        The 'tied' Param is replaced by a larger one. Old values are kept in
        front and the new slots are left uninitialised. The label buffer is then
        re-laid-out to the new global indices.

        :returns: the global indices of the *num* new tie slots.
        """
        if self.tied_param is None:
            new_buf = np.empty((num,))
        else:
            new_buf = np.empty((self.tied_param.size + num,))
            new_buf[:self.tied_param.size] = self.tied_param.param_array.copy()
            self.remove_parameter(self.tied_param)
        self.tied_param = Param('tied', new_buf)
        self.add_parameter(self.tied_param)
        buf_idx_new = self._highest_parent_._raveled_index_for(self.tied_param)
        self._expand_label_buf(self.buf_idx, buf_idx_new)
        self.buf_idx = buf_idx_new
        return self.buf_idx[-num:]

    def _removeTieParam(self, idx):
        """Remove the tie slots at positions *idx* (indices within tied_param)."""
        new_buf = np.empty((self.tied_param.size - len(idx),))
        bool_list = np.ones((self.tied_param.size,), dtype=bool)
        bool_list[idx] = False
        new_buf[:] = self.tied_param.param_array[bool_list]
        self.remove_parameter(self.tied_param)
        self.tied_param = Param('tied', new_buf)
        self.add_parameter(self.tied_param)
        buf_idx_new = self._highest_parent_._raveled_index_for(self.tied_param)
        self._shrink_label_buf(self.buf_idx, buf_idx_new, bool_list)
        self.buf_idx = buf_idx_new

    def _expand_label_buf(self, idx_old, idx_new):
        """Expand label buffer accordingly (old tie-slot labels move to idx_new's front)."""
        if idx_old is None:
            self.label_buf = np.zeros(self._highest_parent_.param_array.shape, dtype=int)
        else:
            bool_old = np.zeros((self.label_buf.size,), dtype=bool)
            bool_old[idx_old] = True
            bool_new = np.zeros((self._highest_parent_.param_array.size,), dtype=bool)
            bool_new[idx_new] = True
            label_buf_new = np.zeros(self._highest_parent_.param_array.shape, dtype=int)
            # Labels of ordinary parameters keep their relative order.
            label_buf_new[np.logical_not(bool_new)] = self.label_buf[np.logical_not(bool_old)]
            label_buf_new[idx_new[:len(idx_old)]] = self.label_buf[idx_old]
            self.label_buf = label_buf_new

    def _shrink_label_buf(self, idx_old, idx_new, bool_list):
        """Shrink label buffer accordingly, keeping only tie-slot labels where bool_list is True."""
        bool_old = np.zeros((self.label_buf.size,), dtype=bool)
        bool_old[idx_old] = True
        bool_new = np.zeros((self._highest_parent_.param_array.size,), dtype=bool)
        bool_new[idx_new] = True
        # Every entry is written below, so uninitialised memory is never exposed.
        label_buf_new = np.empty(self._highest_parent_.param_array.shape, dtype=int)
        label_buf_new[np.logical_not(bool_new)] = self.label_buf[np.logical_not(bool_old)]
        label_buf_new[idx_new] = self.label_buf[idx_old[bool_list]]
        self.label_buf = label_buf_new

    def _check_change(self):
        """
        Re-synchronise each tie group after an external change.

        For each tie slot *i*: if exactly one entry of its group differs from the
        tie value, that entry's value is copied to the whole group. If several
        differ, the tie value is copied to the whole group.

        :returns: True if any group had to be re-synchronised.
        """
        # NOTE(review): the group mask also includes the tie slot itself, so for a
        # group with a single tied member, changing the tie value also yields
        # exactly one mismatch and the tie is reverted to the member's value -
        # confirm whether that is intended.
        changed = False
        if self.tied_param is not None:
            for i in range(self.tied_param.size):
                b0 = self.label_buf == self.label_buf[self.buf_idx[i]]
                b = self._highest_parent_.param_array[b0] != self.tied_param[i]
                n_diff = b.sum()
                if n_diff == 0:
                    continue
                elif n_diff == 1:
                    val = self._highest_parent_.param_array[b0][b][0]
                    self._highest_parent_.param_array[b0] = val
                else:
                    self._highest_parent_.param_array[b0] = self.tied_param[i]
                changed = True
        return changed

    def parameters_changed(self):
        """Re-synchronise tie groups, re-trigger the parent if needed, and collate gradients."""
        #ensure all out parameters have the correct value, as specified by our mapping
        changed = self._check_change()
        if changed:
            self._highest_parent_._trigger_params_changed()
        self.collate_gradient()

    def collate_gradient(self):
        """Set each tie slot's gradient to the sum of the gradients over its tie group."""
        if self.tied_param is not None:
            self.tied_param.gradient = 0.
            for i in range(self.tied_param.size):
                group = self.label_buf == self.label_buf[self.buf_idx[i]]
                np.put(self.tied_param.gradient, i, self._highest_parent_.gradient[group].sum())

    def propagate_val(self):
        """Copy each tie slot's value to every parameter in its tie group."""
        if self.tied_param is not None:
            for i in range(self.tied_param.size):
                self._highest_parent_.param_array[self.label_buf == self.label_buf[self.buf_idx[i]]] = self.tied_param[i]
252