Package paramz :: Package core :: Module indexable
[hide private]
[frames] | [no frames]

Source Code for Module paramz.core.indexable

  1  #=============================================================================== 
  2  # Copyright (c) 2015, Max Zwiessele 
  3  # All rights reserved. 
  4  # 
  5  # Redistribution and use in source and binary forms, with or without 
  6  # modification, are permitted provided that the following conditions are met: 
  7  # 
  8  # * Redistributions of source code must retain the above copyright notice, this 
  9  #   list of conditions and the following disclaimer. 
 10  # 
 11  # * Redistributions in binary form must reproduce the above copyright notice, 
 12  #   this list of conditions and the following disclaimer in the documentation 
 13  #   and/or other materials provided with the distribution. 
 14  # 
 15  # * Neither the name of paramz.core.indexable nor the names of its 
 16  #   contributors may be used to endorse or promote products derived from 
 17  #   this software without specific prior written permission. 
 18  # 
 19  # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
 20  # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
 21  # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 
 22  # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 
 23  # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 
 24  # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
 25  # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 
 26  # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 
 27  # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
 28  # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 29  #=============================================================================== 
 30  import numpy as np 
 31  from .nameable import Nameable 
 32  from .updateable import Updateable 
 33  from ..transformations import __fixed__ 
 34  from operator import delitem 
 35  from functools import reduce 
 36  from collections import OrderedDict 
 37   
class Indexable(Nameable, Updateable):
    """
    Make an object constrainable with priors and transformations.

    Every object keeps an ordered mapping ``self._index_operations`` from a
    name (e.g. constraints, priors) to an index-operations object. Each name
    is exposed as a property installed on the ``Indexable`` class by
    ``_add_io``, so ``obj.<name>`` returns ``obj._index_operations[<name>]``.

    Adding a constraint to a parameter means telling the highest parent that
    the constraint was added, and making sure that all parameters covered by
    this object conform to that constraint.

    :func:`constrain()` and :func:`unconstrain()` are the main user-facing
    methods built on top of this bookkeeping.

    TODO: Mappings!! (As in ties etc.)
    """

    def __init__(self, name, default_constraint=None, *a, **kw):
        """
        :param name: name of this object, forwarded to the base classes.
        :param default_constraint: accepted for backwards compatibility;
            it is not used by this class.

        Any further positional and keyword arguments are forwarded to the
        next ``__init__`` in the MRO.
        """
        super(Indexable, self).__init__(name=name, *a, **kw)
        # name -> index operations object; insertion order is preserved.
        self._index_operations = OrderedDict()

    def __setstate__(self, state):
        """
        Restore from a pickled ``state`` and re-install the index-operation
        properties.

        Objects pickled before ``_index_operations`` existed kept their index
        operations as plain attributes. Those are collected from ``state``
        into a fresh ``_index_operations`` mapping.
        """
        super(Indexable, self).__setstate__(state)
        if not hasattr(self, "_index_operations"):
            # Old pickle format: recover index operations stored as attributes.
            self._index_operations = OrderedDict()
            # Local imports, presumably to avoid circular imports at module load.
            from paramz.core.index_operations import ParameterIndexOperations, ParameterIndexOperationsView
            from paramz import Param
            if isinstance(self, Param):
                # For Param the attribute mapping is the second item of the state.
                state = state[1]
            # .items() works on Python 2 and 3; .iteritems() is Python-2-only.
            for name, io in state.items():
                if isinstance(io, (ParameterIndexOperations, ParameterIndexOperationsView)):
                    self._index_operations[name] = io
        # The properties live on the class, so re-install them for this process.
        for name, io in list(self._index_operations.items()):
            self._add_io(name, io)

    def add_index_operation(self, name, operations):
        """
        Register ``operations`` under ``name`` and expose it as ``self.<name>``.

        :raises AttributeError: if an index operation called ``name`` is
            already registered on this object.
        """
        if name in self._index_operations:
            raise AttributeError("An index operation with the name {} was already taken".format(name))
        self._add_io(name, operations)

    def _add_io(self, name, operations):
        """
        Store ``operations`` under ``name`` and install a ``name`` property.

        The property is set on the ``Indexable`` class itself, so it is shared
        by every instance and subclass of ``Indexable``:

        * reading ``obj.<name>`` returns ``obj._index_operations[name]``
          (KeyError if ``obj`` has no such entry);
        * assigning ``obj.<name> = x`` replaces the entry and then calls
          ``obj._connect_fixes()`` and ``obj._notify_parent_change()``.
        """
        self._index_operations[name] = operations

        def _set_io(obj, value):
            obj._index_operations[name] = value
            # Re-sync state derived from the index operations after replacing one.
            obj._connect_fixes()
            obj._notify_parent_change()

        setattr(Indexable, name,
                property(fget=lambda obj: obj._index_operations[name],
                         fset=_set_io))

    def remove_index_operation(self, name):
        """
        Remove the index operation registered under ``name``.

        Only this object's entry is removed. The class-level property
        installed by ``_add_io`` stays in place, so reading ``self.<name>``
        afterwards raises KeyError.

        :raises AttributeError: if no index operation ``name`` is registered.
        """
        if name not in self._index_operations:
            raise AttributeError("No index operation with the name {}".format(name))
        del self._index_operations[name]

    def _disconnect_parent(self, *args, **kw):
        """
        From Parentable:
        detach this object from its parent.

        Each index operation is replaced by an independent copy. The original
        is cleared first; when it is a view into the parent's operations (see
        ``_parent_changed``), this presumably also removes this object's
        entries from the parent. The parent links are then reset, and fixes
        and parent notifications are refreshed. Extra arguments are ignored.
        """
        for name, iop in list(self._index_operations.items()):
            iopc = iop.copy()
            iop.clear()
            self.remove_index_operation(name)
            self.add_index_operation(name, iopc)
        self._parent_ = None
        self._parent_index_ = None
        self._connect_fixes()
        self._notify_parent_change()

    #===========================================================================
    # Indexable
    #===========================================================================
    def _offset_for(self, param):
        """
        Return the offset of ``param`` inside this parameterized object.

        Shapes do not matter here. The offset is the sum of the sizes of all
        parameters that come before ``param``. If ``param`` is not a direct
        child of ``self``, the offset is accumulated recursively through
        ``param``'s parent. A parameter without a parent has offset 0.
        """
        if param.has_parent():
            p = param._parent_._get_original(param)
            if p in self.parameters:
                # Direct child: add up the sizes of the preceding siblings.
                return sum(q.size for q in self.parameters[:p._parent_index_])
            return self._offset_for(param._parent_) + param._parent_._offset_for(param)
        return 0

    # Global index operations (call these on the highest parent).
    # These indices are for gradchecking, so that the optimizer array can be
    # indexed and manipulated directly. They do not reflect the indices stored
    # in the index operations, which handle their offsets themselves and can
    # be set directly without applying an offset.
    def _raveled_index_for(self, param):
        """
        Return the raveled (flattened) index of ``param`` within ``self``.

        The result is an int array holding, for every entry of the flattened
        ``param``, its position inside this parameterized object. For a
        ``ParamConcatenation`` the indices of its members are concatenated.

        !Warning! the offset is computed relative to ``self``. Call this on
        the highest parent of a hierarchy to obtain global indices.
        """
        # Local import, presumably to avoid a circular import at module load.
        from ..param import ParamConcatenation
        if isinstance(param, ParamConcatenation):
            # np.hstack needs a sequence; generators are deprecated since
            # NumPy 1.16 and rejected with TypeError by newer releases.
            return np.hstack([self._raveled_index_for(p) for p in param.params])
        return param._raveled_index() + self._offset_for(param)

    def _raveled_index_for_transformed(self, param):
        """
        Return the raveled index of ``param`` in the transformed parameter
        array (the optimizer array).

        When this object has fixes, entries whose ``self._fixes_`` flag is
        False (presumably fixed parameters) are dropped. The remaining
        indices are shifted down by the number of such entries before them.

        !Warning! call this on the highest parent of a hierarchy, as it uses
        its fixes. If you do not know what you are doing, do not use this
        method; it will have unexpected returns!
        """
        ravi = self._raveled_index_for(param)
        if not self._has_fixes():
            return ravi
        fixes = self._fixes_
        # (~fixes).cumsum()[i] counts the dropped entries up to and including i.
        transformed = np.r_[:self.size] - (~fixes).cumsum()
        return transformed[ravi[fixes[ravi]]]

    # The raveled index for self is what the index operations use. They apply
    # their offsets themselves, which keeps the index-operations framework
    # self-contained and testable outside the parameterized scope.
    def _raveled_index(self):
        """
        Flattened array of ints, specifying the index of this object.

        This base implementation returns ``0 .. self.size - 1``.
        NOTE(review): shaped parameters must be accounted for here; confirm
        that subclasses needing this override the method.
        """
        return np.r_[:self.size]

    # TODO: tie parameters together (create own class for tying and remapping).

    def _parent_changed(self, parent):
        """
        From Parentable:
        called when the parent changed.

        Replace every index operation of this object with a
        ``ParameterIndexOperationsView`` onto the operation of the same name
        in ``parent``. The view is offset by this object's position in
        ``parent``, so constraining this object acts on the parent's index
        operations. ``_fixes_`` is reset to None and the change is propagated
        to all child parameters.
        """
        # Local import, presumably to avoid a circular import at module load.
        from .index_operations import ParameterIndexOperationsView
        offset = parent._offset_for(self)
        for name in list(self._index_operations):
            self.remove_index_operation(name)
            self.add_index_operation(
                name,
                ParameterIndexOperationsView(parent._index_operations[name], offset, self.size))
        self._fixes_ = None
        for p in self.parameters:
            p._parent_changed(parent)

    def _add_to_index_operations(self, which, reconstrained, what, warning):
        """
        Add ``what`` (a transformation, prior, ...) for all of this object's
        indices to the index operations ``which``.

        :param which: index operations to add to.
        :param reconstrained: indices being reconstrained. Only its ``size``
            is used, to decide whether to warn.
        :param what: the property to add.
        :param warning: print a warning when reconstraining parameters.
        :returns: the raveled index that ``what`` was added for.

        TODO: find out which parameters have changed specifically.
        """
        if warning and reconstrained.size > 0:
            # TODO: figure out which parameters have changed and only print those
            print("WARNING: reconstraining parameters {}".format(self.hierarchy_name() or self.name))
        index = self._raveled_index()
        which.add(what, index)
        return index

    def _remove_from_index_operations(self, which, transforms):
        """
        Remove the given properties (transforms, priors, ...) for this
        object's indices from the index operations ``which``.

        :param which: index operations to remove from.
        :param transforms: properties to remove. If empty, every property
            returned by ``which.properties()`` is removed.
        :returns: sorted union (``np.union1d``) of the indices returned by
            ``which.remove``; an empty int array if there were none.

        Removing ``__fixed__`` also unfixes the affected indices on the
        highest parent.
        """
        if len(transforms) == 0:
            transforms = which.properties()
        removed = np.empty((0,), dtype=int)
        for t in list(transforms):
            unconstrained = which.remove(t, self._raveled_index())
            removed = np.union1d(removed, unconstrained)
            if t is __fixed__:
                self._highest_parent_._set_unfixed(self, unconstrained)
        return removed