Package paramz :: Module parameterized
[hide private]
[frames] | no frames]

Source Code for Module paramz.parameterized

  1  #=============================================================================== 
  2  # Copyright (c) 2015, Max Zwiessele 
  3  # All rights reserved. 
  4  # 
  5  # Redistribution and use in source and binary forms, with or without 
  6  # modification, are permitted provided that the following conditions are met: 
  7  # 
  8  # * Redistributions of source code must retain the above copyright notice, this 
  9  #   list of conditions and the following disclaimer. 
 10  # 
 11  # * Redistributions in binary form must reproduce the above copyright notice, 
 12  #   this list of conditions and the following disclaimer in the documentation 
 13  #   and/or other materials provided with the distribution. 
 14  # 
 15  # * Neither the name of paramax nor the names of its 
 16  #   contributors may be used to endorse or promote products derived from 
 17  #   this software without specific prior written permission. 
 18  # 
 19  # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
 20  # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
 21  # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 
 22  # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 
 23  # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 
 24  # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
 25  # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 
 26  # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 
 27  # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
 28  # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 29  #=============================================================================== 
 30   
 31  import six # For metaclass support in Python 2 and 3 simultaneously 
 32  import numpy; np = numpy 
 33  from re import compile, _pattern_type 
 34   
 35  from .core.parameter_core import Parameterizable, adjust_name_for_printing 
 36  from .core import HierarchyError 
 37   
 38  import logging 
 39  from collections import OrderedDict 
 40  from functools import reduce 
 41  logger = logging.getLogger("parameters changed meta") 
class ParametersChangedMeta(type):
    """Metaclass that finishes parameter setup after ``__init__`` has returned.

    Instantiating a class that uses this metaclass accepts one extra keyword
    argument, ``initialize`` (default ``True``). It is removed from the
    keyword arguments before the regular constructor is invoked.
    """

    def __call__(self, *args, **kw):
        """Create an instance and, unless told otherwise, initialize it.

        :param initialize: popped from ``kw``. When true,
            ``initialize_parameter()`` is called on the new instance; when
            false a ``RuntimeWarning`` reminds the caller to do so.
        :returns: the newly created instance, with the text produced by
            ``_inherit_doc`` appended to its ``__doc__``.
        """
        # ``self`` is the class being instantiated: the flag is set on the
        # class before the constructor runs.
        self._in_init_ = True
        run_initialize = kw.pop('initialize', True)
        instance = super(ParametersChangedMeta, self).__call__(*args, **kw)
        # Instance-level flags, set once the constructor has returned.
        instance._in_init_ = False
        instance._model_initialized_ = False
        if not run_initialize:
            import warnings
            warnings.warn("Don't forget to initialize by self.initialize_parameter()!", RuntimeWarning)
        else:
            instance.initialize_parameter()
        from .util import _inherit_doc
        instance.__doc__ = (instance.__doc__ or '') + _inherit_doc(instance.__class__)
        return instance
61 62 from six import with_metaclass
63 #@six.add_metaclass(ParametersChangedMeta) 64 -class Parameterized(with_metaclass(ParametersChangedMeta, Parameterizable)):
65 """ 66 Say m is a handle to a parameterized class. 67 68 Printing parameters:: 69 70 - print m: prints a nice summary over all parameters 71 - print m.name: prints details for param with name 'name' 72 - print m[regexp]: prints details for all the parameters 73 which match (!) regexp 74 - print m['']: prints details for all parameters 75 76 Fields:: 77 78 Name: The name of the param, can be renamed! 79 Value: Shape or value, if one-valued 80 Constrain: constraint of the param, curly "{c}" brackets indicate 81 some parameters are constrained by c. See detailed print 82 to get exact constraints. 83 Tied_to: which parameter it is tied to. 84 85 Getting and setting parameters:: 86 87 - Set all values in param to one: m.name.to.param = 1 88 - Set all values in parameterized: m.name[:] = 1 89 - Set values to random values: m[:] = np.random.normal(size=m.size) 90 91 Handling of constraining, fixing and tying parameters:: 92 93 - You can constrain parameters by calling the constrain on the param itself, e.g.: 94 95 - m.name[:,1].constrain_positive() 96 - m.name[0].tie_to(m.name[1]) 97 98 - Fixing parameters will fix them to the value they are right now. If you change 99 the parameter's value, the param will be fixed to the new value! 100 101 - If you want to operate on all parameters use m[''] to wildcard select all parameters 102 and concatenate them. Printing m[''] will result in printing of all parameters in detail. 103 """ 104 #=========================================================================== 105 # Metaclass for parameters changed after init. 106 # This makes sure that parameters_changed will always be called after __init__ 107 # **Never** call parameters_changed() yourself 108 #This is ignored in Python 3 -- you need to put the meta class in the class definition. 109 #__metaclass__ = ParametersChangedMeta 110 #The six module is used to support both Python 2 and 3 simultaneously 111 #===========================================================================
def __init__(self, name=None, parameters=None):
    """Set up the bookkeeping attributes and link the given parameters.

    :param name: name handed on to the parent class constructor.
    :param parameters: optional iterable of parameter objects; each one is
        passed to ``link_parameters``. ``None`` (the default) links nothing.
    """
    super(Parameterized, self).__init__(name=name)
    self.size = sum(p.size for p in self.parameters)
    # Register this object as an observer of itself.
    self.add_observer(self, self._parameters_changed_notification, -100)
    self._fixes_ = None
    self._param_slices_ = []
    # A ``None`` sentinel replaces the former ``[]`` default, which was a
    # single list object shared by every call that omitted the argument.
    if parameters is None:
        parameters = ()
    self.link_parameters(*parameters)
120 121 #=========================================================================== 122 # Add remove parameters: 123 #=========================================================================== 143 param.traverse_parents(visit, self) 144 param._parent_.unlink_parameter(param) 145 # make sure the size is set 146 if index is None: 147 start = sum(p.size for p in self.parameters) 148 for name, iop in self._index_operations.items(): 149 iop.shift_right(start, param.size) 150 iop.update(param._index_operations[name], self.size) 151 param._parent_ = self 152 param._parent_index_ = len(self.parameters) 153 self.parameters.append(param) 154 else: 155 start = sum(p.size for p in self.parameters[:index]) 156 for name, iop in self._index_operations.items(): 157 iop.shift_right(start, param.size) 158 iop.update(param._index_operations[name], start) 159 param._parent_ = self 160 param._parent_index_ = index if index>=0 else len(self.parameters[:index]) 161 for p in self.parameters[index:]: 162 p._parent_index_ += 1 163 self.parameters.insert(index, param) 164 165 param.add_observer(self, self._pass_through_notify_observers, -np.inf) 166 167 parent = self 168 while parent is not None: 169 parent.size += param.size 170 parent = parent._parent_ 171 self._notify_parent_change() 172 173 if not self._in_init_ and self._highest_parent_._model_initialized_: 174 #self._connect_parameters() 175 #self._notify_parent_change() 176 177 self._highest_parent_._connect_parameters() 178 self._highest_parent_._notify_parent_change() 179 self._highest_parent_._connect_fixes() 180 return param 181 else: 182 raise HierarchyError("""Parameter exists already, try making a copy""")
183 184 191 224
def _connect_parameters(self, ignore_added_names=False):
    """Connect the children in ``self.parameters`` to this object's arrays.

    The model is flagged as initialized first. Then, for every child in
    order, the child's parent and index are recorded, the matching window of
    ``self.param_array`` / ``self.gradient_full`` is handed to the child via
    ``_propagate_param_grad``, the child's current values are copied into
    that window, and the child's arrays are re-pointed at the window's
    buffer. The window is appended to ``self._param_slices_`` and the child
    is registered through ``_add_parameter_name``.

    :param ignore_added_names: accepted but not used by this method.
    :raises ValueError: if a child's ``param_array`` is not C-contiguous.
    """
    self._model_initialized_ = True

    # Objects without any parameters have nothing to connect.
    if not (hasattr(self, "parameters") and len(self.parameters) >= 1):
        return

    offset = 0
    self._param_slices_ = []
    for position, child in enumerate(self.parameters):
        if not child.param_array.flags['C_CONTIGUOUS']:
            raise ValueError("""
Have you added an additional dimension to a Param object?

p[:,None], where p is of type Param does not work
and is expected to fail! Try increasing the
dimensionality of the param array before making
a Param out of it:
p = Param("<name>", array[:,None])

Otherwise this should not happen!
Please write an email to the developers with the code,
which reproduces this error.
All parameter arrays must be C_CONTIGUOUS
""")

        child._parent_ = self
        child._parent_index_ = position

        window = slice(offset, offset + child.size)

        # The child's own children are connected first ...
        child._propagate_param_grad(self.param_array[window], self.gradient_full[window])

        # ... then the child's values are copied into this object's arrays ...
        self.param_array[window] = child.param_array.flat
        self.gradient_full[window] = child.gradient_full.flat

        # ... and the child's arrays are pointed at this object's buffer.
        child.param_array.data = self.param_array[window].data
        child.gradient_full.data = self.gradient_full[window].data

        self._param_slices_.append(window)

        self._add_parameter_name(child)
        offset += child.size
275 276 #=========================================================================== 277 # Get/set parameters: 278 #===========================================================================
def grep_param_names(self, regexp):
    """
    Create a list of parameters matching the regular expression regexp.

    :param regexp: a pattern string or an already compiled pattern.
    :returns: list of every object visited by ``self.traverse`` (this object
        itself excluded) for which ``regexp.match`` succeeds on the part of
        ``hierarchy_name()`` that follows the first ``'.'``.
    """
    # The compiled-pattern type is taken from ``compile`` itself rather than
    # from the private ``re._pattern_type`` name, which Python 3.7 removed.
    if not isinstance(regexp, type(compile(''))):
        regexp = compile(regexp)
    found_params = []

    def visit(innerself, regexp):
        if innerself is self:
            return
        # Everything after the first dot of the hierarchy name.
        relative_name = innerself.hierarchy_name().partition('.')[2]
        if regexp.match(relative_name):
            found_params.append(innerself)

    self.traverse(visit, regexp)
    return found_params
def __getitem__(self, name, paramlist=None):
    """Look up values by array index or parameters by name pattern.

    :param name: an ``int``, ``slice``, ``tuple`` or ``numpy.ndarray`` is
        used to index ``self.param_array`` directly; anything else is passed
        to ``grep_param_names`` as a regular expression.
    :param paramlist: accepted but ignored; the list is always rebuilt.
    :returns: the indexed part of ``param_array``, the single matching
        parameter, or a ``ParamConcatenation`` of all matches.
    :raises AttributeError: if no parameter matches ``name``.
    """
    if isinstance(name, (int, slice, tuple, np.ndarray)):
        return self.param_array[name]

    matches = self.grep_param_names(name)
    if len(matches) < 1:
        raise AttributeError(name)
    if len(matches) != 1:
        from .param import ParamConcatenation
        return ParamConcatenation(matches)
    return matches[0]
306
def __setitem__(self, name, value, paramlist=None):
    """Assign values by array index or by parameter name pattern.

    :param name: an ``int``, ``slice``, ``tuple`` or ``numpy.ndarray`` writes
        straight into ``self.param_array`` and then calls
        ``trigger_update()``; anything else is resolved through
        ``__getitem__`` and the result is assigned with ``[:] = value``.
    :param value: the new value(s); ``None`` leaves everything untouched.
    :param paramlist: handed on to ``__getitem__``.
    :raises AttributeError: if the model has not been initialized yet.
    :raises ValueError: if the assignment into ``param_array`` fails.
    """
    if not self._model_initialized_:
        raise AttributeError("""Model is not initialized, this change will only be reflected after initialization if in leaf.

If you are loading a model, set updates off, then initialize, then set the values, then update the model to be fully initialized:
>>> m.update_model(False)
>>> m.initialize_parameter()
>>> m[:] = loaded_parameters
>>> m.update_model(True)
""")
    if value is None:
        return  # nothing to do here
    # ``int`` is included so the accepted index types agree with
    # ``__getitem__``; a plain integer used to fall through to the name
    # branch, where assigning into the returned scalar failed.
    if isinstance(name, (int, slice, tuple, np.ndarray)):
        try:
            self.param_array[name] = value
        except Exception:
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt/SystemExit are not turned into ValueError.
            raise ValueError("Setting by slice or index only allowed with array-like")
        self.trigger_update()
    else:
        param = self.__getitem__(name, paramlist)
        param[:] = value
328
def __setattr__(self, name, val):
    """Assign into a direct child parameter when ``name`` refers to one.

    If this object has a ``parameters`` attribute and ``name`` is among the
    non-recursive, print-adjusted parameter names, ``val`` is written into
    that parameter with ``param[:] = val`` (so broadcasting can be used).
    In every other case the default ``object.__setattr__`` is used.
    """
    if not hasattr(self, "parameters"):
        return object.__setattr__(self, name, val)

    direct_names = self.parameter_names(False, adjust_for_printing=True, recursive=False)
    if name not in direct_names:
        return object.__setattr__(self, name, val)

    target = self.parameters[direct_names.index(name)]
    target[:] = val
    return
337 338 #=========================================================================== 339 # Pickling 340 #===========================================================================
def __setstate__(self, state):
    """Restore pickled state and rebuild the parameter connections.

    After the parent class has restored ``state``, the parameters and fixes
    are reconnected, the parent change is announced and
    ``parameters_changed()`` is called, in that order.

    :param state: the state object handed on to the parent ``__setstate__``.
    :returns: this object.
    """
    super(Parameterized, self).__setstate__(state)
    # Each hook is looked up only when its turn comes, in this order.
    for hook in ("_connect_parameters", "_connect_fixes",
                 "_notify_parent_change", "parameters_changed"):
        getattr(self, hook)()
    return self
348
def copy(self, memo=None):
    """Return a copy of this object with its parameters reconnected.

    :param memo: optional memo dictionary handed on to the parent ``copy``.
        Entries keyed by the ids of ``self.optimizer_array`` and
        ``self.param_array`` are set to ``None`` in it beforehand.
    :returns: the copy, after ``_connect_parameters()``,
        ``_connect_fixes()`` and ``_notify_parent_change()`` ran on it.
    """
    memo = {} if memo is None else memo
    for array_name in ("optimizer_array", "param_array"):
        memo[id(getattr(self, array_name))] = None
    clone = super(Parameterized, self).copy(memo)
    for hook in ("_connect_parameters", "_connect_fixes", "_notify_parent_change"):
        getattr(clone, hook)()
    return clone
359 360 #=========================================================================== 361 # Printing: 362 #===========================================================================
def _short(self):
    """Return the short description of this object: its hierarchy name."""
    short_name = self.hierarchy_name()
    return short_name
@property
def flattened_parameters(self):
    """List of the ``flattened_parameters`` of every child, in order."""
    flattened = []
    for child in self.parameters:
        flattened.extend(child.flattened_parameters)
    return flattened
368
def get_property_string(self, propname):
    """Gather the property strings of all children into one flat list.

    :param propname: name handed on to each child's ``get_property_string``.
    :returns: the children's results concatenated in parameter order.
    """
    return [text
            for child in self.parameters
            for text in child.get_property_string(propname)]
@property
def _description_str(self):
    """List of the ``_description_str`` entries of every child, in order."""
    descriptions = []
    for child in self.parameters:
        descriptions.extend(child._description_str)
    return descriptions
378
def _repr_html_(self, header=True):
    """Representation of the parameters in html for notebook display.

    :param header: when true, a header row with the model name, ``value``
        and the index operation names is put in front of the table rows.
    :returns: a ``<style>`` block followed by one ``<table class="tg">``
        holding a row per parameter name.
    """
    title = adjust_name_for_printing(self.name) + "."
    names = self.parameter_names()
    desc = self._description_str

    # One column per index operation, filled from every direct child.
    iops = OrderedDict()
    for opname in self._index_operations:
        iops[opname] = [text
                        for child in self.parameters
                        for text in child.get_property_string(opname)]

    columns = self._format_spec(title, names, desc, iops, False)
    rows = []

    if header:
        header_cells = '</b></th><th><b>'.join(columns)
        rows.append("<tr><th><b>" + header_cells.format(name=title, desc='value', **dict((opname, opname) for opname in iops)) + "</b></th></tr>")

    row_template = ("<tr><td class=tg-left>" + columns[0]
                    + '</td><td class=tg-right>' + columns[1]
                    + '</td><td class=tg-center>' + '</td><td class=tg-center>'.join(columns[2:])
                    + "</td></tr>")
    for i, pname in enumerate(names):
        rows.append(row_template.format(name=pname, desc=desc[i], **dict((opname, iops[opname][i]) for opname in iops)))

    style = """<style type="text/css">
.tg {font-family:"Courier New", Courier, monospace !important;padding:2px 3px;word-break:normal;border-collapse:collapse;border-spacing:0;border-color:#DCDCDC;margin:0px auto;width:100%;}
.tg td{font-family:"Courier New", Courier, monospace !important;font-weight:bold;color:#444;background-color:#F7FDFA;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#DCDCDC;}
.tg th{font-family:"Courier New", Courier, monospace !important;font-weight:normal;color:#fff;background-color:#26ADE4;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#DCDCDC;}
.tg .tg-left{font-family:"Courier New", Courier, monospace !important;font-weight:normal;text-align:left;}
.tg .tg-center{font-family:"Courier New", Courier, monospace !important;font-weight:normal;text-align:center;}
.tg .tg-right{font-family:"Courier New", Courier, monospace !important;font-weight:normal;text-align:right;}
</style>"""
    return style + '\n' + '<table class="tg">' + '\n'.join(rows) + '\n</table>'
410
def _format_spec(self, name, names, desc, iops, VT100=True):
    """Build the per-column format strings for printing the parameters.

    :param name: heading of the name column.
    :param names: list of parameter names.
    :param desc: list of value descriptions.
    :param iops: mapping from index operation name to its list of strings.
    :param VT100: when true, the name column is wrapped in terminal escape
        sequences.
    :returns: list of format strings: the ``name`` column, the ``desc``
        column, then one centred column per key of ``iops``; each column is
        as wide as its longest entry (heading included).
    """
    name_width = max(len(str(entry)) for entry in names + [name])
    value_width = max(len(str(entry)) for entry in desc + ["value"])

    op_widths = [max([len(opname)] + [len(entry) for entry in iops[opname]])
                 for opname in iops]

    if VT100:
        name_column = " \033[1m{{name!s:<{0}}}\033[0;0m"
    else:
        name_column = " {{name!s:<{0}}}"
    spec = [name_column.format(name_width), "{{desc!s:>{0}}}".format(value_width)]

    spec.extend('{{{1}!s:^{0}}}'.format(width, opname)
                for opname, width in zip(iops, op_widths))
    return spec
427
def __str__(self, header=True, VT100=True):
    """Return a table with one line per parameter name.

    :param header: when true, the first line names the columns: the model
        name, ``value`` and every index operation.
    :param VT100: handed on to ``_format_spec``.
    :returns: the lines joined by newlines.
    """
    title = adjust_name_for_printing(self.name) + "."
    names = self.parameter_names(adjust_for_printing=True)
    desc = self._description_str
    iops = OrderedDict((opname, self.get_property_string(opname))
                       for opname in self._index_operations)

    row_format = ' | '.join(self._format_spec(title, names, desc, iops, VT100))

    lines = []
    if header:
        lines.append(row_format.format(name=title, desc='value', **dict((opname, opname) for opname in iops)))

    for i, pname in enumerate(names):
        lines.append(row_format.format(name=pname, desc=desc[i], **dict((opname, iops[opname][i]) for opname in iops)))
    return '\n'.join(lines)
446
def build_pydot(self, G=None): # pragma: no cover
    """
    Build a pydot representation of this model. This needs pydot installed.

    Example Usage::

        np.random.seed(1000)
        X = np.random.normal(0,1,(20,2))
        beta = np.random.uniform(0,1,(2,1))
        Y = X.dot(beta)
        m = RidgeRegression(X, Y)
        G = m.build_pydot()
        G.write_png('example_hierarchy_layout.png')

    The output looks like:

    .. image:: ./example_hierarchy_layout.png

    Rectangles are parameterized objects (nodes or leafs of hierarchy).

    Trapezoids are param objects, which represent the arrays for parameters.

    Black arrows show parameter hierarchical dependence. The arrow points
    from parents towards children.

    Orange arrows show the observer pattern. Self references (here) are
    the references to the call to parameters changed and references upwards
    are the references to tell the parents they need to update.

    :param G: graph to add to; when ``None`` a new directed graph is created
        and returned, otherwise the node created for this object is returned.
    """
    import pydot  # @UnresolvedImport
    iamroot = G is None
    if iamroot:
        G = pydot.Dot(graph_type='digraph', bgcolor=None)
    node = pydot.Node(id(self), shape='box', label=self.name)

    G.add_node(node)
    for child in self.parameters:
        child_node = child.build_pydot(G)
        G.add_edge(pydot.Edge(node, child_node))

    for _, observer, _ in self.observers:
        if hasattr(observer, 'name'):
            label = observer.name
        else:
            label = str(observer)
        observed_node = pydot.Node(id(observer), label=label)
        # Only add the observer's node if the graph does not hold it yet.
        if str(id(observer)) not in G.obj_dict['nodes']:
            G.add_node(observed_node)
        G.add_edge(pydot.Edge(str(id(self)), str(id(observer)), color='darkorange2', arrowhead='vee'))

    return G if iamroot else node