| Home | Trees | Indices | Help |
|---|
|
|
1 #===============================================================================
2 # Copyright (c) 2015, Max Zwiessele
3 # All rights reserved.
4 #
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are met:
7 #
8 # * Redistributions of source code must retain the above copyright notice, this
9 # list of conditions and the following disclaimer.
10 #
11 # * Redistributions in binary form must reproduce the above copyright notice,
12 # this list of conditions and the following disclaimer in the documentation
13 # and/or other materials provided with the distribution.
14 #
15 # * Neither the name of paramax nor the names of its
16 # contributors may be used to endorse or promote products derived from
17 # this software without specific prior written permission.
18 #
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 #===============================================================================
30
31 import six # For metaclass support in Python 2 and 3 simultaneously
32 import numpy; np = numpy
33 from re import compile, _pattern_type
34
35 from .core.parameter_core import Parameterizable, adjust_name_for_printing
36 from .core import HierarchyError
37
38 import logging
39 from collections import OrderedDict
40 from functools import reduce
41 logger = logging.getLogger("parameters changed meta")
44
46 self._in_init_ = True
47 #import ipdb;ipdb.set_trace()
48 initialize = kw.pop('initialize', True)
49 self = super(ParametersChangedMeta, self).__call__(*args, **kw)
50 #logger.debug("finished init")
51 self._in_init_ = False
52 self._model_initialized_ = False
53 if initialize:
54 self.initialize_parameter()
55 else:
56 import warnings
57 warnings.warn("Don't forget to initialize by self.initialize_parameter()!", RuntimeWarning)
58 from .util import _inherit_doc
59 self.__doc__ = (self.__doc__ or '') + _inherit_doc(self.__class__)
60 return self
61
62 from six import with_metaclass
63 #@six.add_metaclass(ParametersChangedMeta)
64 -class Parameterized(with_metaclass(ParametersChangedMeta, Parameterizable)):
65 """
66 Say m is a handle to a parameterized class.
67
68 Printing parameters::
69
70 - print m: prints a nice summary over all parameters
71 - print m.name: prints details for param with name 'name'
72 - print m[regexp]: prints details for all the parameters
73 which match (!) regexp
74 - print m['']: prints details for all parameters
75
76 Fields::
77
78 Name: The name of the param, can be renamed!
79 Value: Shape or value, if one-valued
80 Constrain: constraint of the param, curly "{c}" brackets indicate
81 some parameters are constrained by c. See detailed print
82 to get exact constraints.
83 Tied_to: which parameter it is tied to.
84
85 Getting and setting parameters::
86
87 - Set all values in param to one: m.name.to.param = 1
88 - Set all values in parameterized: m.name[:] = 1
89 - Set values to random values: m[:] = np.random.normal(size=m.size)
90
91 Handling of constraining, fixing and tying parameters::
92
93 - You can constrain parameters by calling the constrain on the param itself, e.g:
94
95 - m.name[:,1].constrain_positive()
96 - m.name[0].tie_to(m.name[1])
97
98 - Fixing parameters will fix them to the value they are right now. If you change
99 the parameter's value, the param will be fixed to the new value!
100
101 - If you want to operate on all parameters, use m[''] to wildcard-select all parameters
102 and concatenate them. Printing m[''] will result in printing of all parameters in detail.
103 """
104 #===========================================================================
105 # Metaclass for parameters changed after init.
106 # This makes sure, that parameters changed will always be called after __init__
107 # **Never** call parameters_changed() yourself
108 #This is ignored in Python 3 -- you need to put the metaclass in the class definition.
109 #__metaclass__ = ParametersChangedMeta
110 #The six module is used to support both Python 2 and 3 simultaneously
111 #===========================================================================
113 super(Parameterized, self).__init__(name=name)
114 self.size = sum(p.size for p in self.parameters)
115 self.add_observer(self, self._parameters_changed_notification, -100)
116 self._fixes_ = None
117 self._param_slices_ = []
118 #self._connect_parameters()
119 self.link_parameters(*parameters)
120
121 #===========================================================================
122 # Add remove parameters:
123 #===========================================================================
125 """
126 :param parameters: the parameters to add
127 :type parameters: list of or one :py:class:`paramz.param.Param`
128 :param [index]: index of where to put parameters
129
130 Add all parameters to this param class, you can insert parameters
131 at any given index using the :func:`list.insert` syntax
132 """
133 if param in self.parameters and index is not None:
134 self.unlink_parameter(param)
135 return self.link_parameter(param, index)
136 # elif param.has_parent():
137 # raise HierarchyError, "parameter {} already in another model ({}), create new object (or copy) for adding".format(param._short(), param._highest_parent_._short())
138 elif param not in self.parameters:
139 if param.has_parent():
140 def visit(parent, self):
141 if parent is self:
142 raise HierarchyError("You cannot add a parameter twice into the hierarchy")
143 param.traverse_parents(visit, self)
144 param._parent_.unlink_parameter(param)
145 # make sure the size is set
146 if index is None:
147 start = sum(p.size for p in self.parameters)
148 for name, iop in self._index_operations.items():
149 iop.shift_right(start, param.size)
150 iop.update(param._index_operations[name], self.size)
151 param._parent_ = self
152 param._parent_index_ = len(self.parameters)
153 self.parameters.append(param)
154 else:
155 start = sum(p.size for p in self.parameters[:index])
156 for name, iop in self._index_operations.items():
157 iop.shift_right(start, param.size)
158 iop.update(param._index_operations[name], start)
159 param._parent_ = self
160 param._parent_index_ = index if index>=0 else len(self.parameters[:index])
161 for p in self.parameters[index:]:
162 p._parent_index_ += 1
163 self.parameters.insert(index, param)
164
165 param.add_observer(self, self._pass_through_notify_observers, -np.inf)
166
167 parent = self
168 while parent is not None:
169 parent.size += param.size
170 parent = parent._parent_
171 self._notify_parent_change()
172
173 if not self._in_init_ and self._highest_parent_._model_initialized_:
174 #self._connect_parameters()
175 #self._notify_parent_change()
176
177 self._highest_parent_._connect_parameters()
178 self._highest_parent_._notify_parent_change()
179 self._highest_parent_._connect_fixes()
180 return param
181 else:
182 raise HierarchyError("""Parameter exists already, try making a copy""")
183
184
186 """
187 convenience method for adding several
188 parameters without gradient specification
189 """
190 [self.link_parameter(p) for p in parameters]
191
193 """
194 :param param: param object to remove from being a parameter of this parameterized object.
195 """
196 if not param in self.parameters:
197 try:
198 raise HierarchyError("{} does not belong to this object {}, remove parameters directly from their respective parents".format(param._short(), self.name))
199 except AttributeError:
200 raise HierarchyError("{} does not seem to be a parameter, remove parameters directly from their respective parents".format(str(param)))
201
202 start = sum([p.size for p in self.parameters[:param._parent_index_]])
203 self.size -= param.size
204 del self.parameters[param._parent_index_]
205 self._remove_parameter_name(param)
206
207
208 param._disconnect_parent()
209 param.remove_observer(self, self._pass_through_notify_observers)
210 for name, iop in self._index_operations.items():
211 iop.shift_left(start, param.size)
212
213 self._connect_parameters()
214 self._notify_parent_change()
215
216 parent = self._parent_
217 while parent is not None:
218 parent.size -= param.size
219 parent = parent._parent_
220
221 self._highest_parent_._connect_parameters()
222 self._highest_parent_._connect_fixes()
223 self._highest_parent_._notify_parent_change()
224
226 # connect parameterlist to this parameterized object
227 # This just sets up the right connection for the params objects
228 # to be used as parameters
229 # it also sets the constraints for each parameter to the constraints
230 # of their respective parents
231 self._model_initialized_ = True
232
233 if not hasattr(self, "parameters") or len(self.parameters) < 1:
234 # no parameters for this class
235 return
236
237 old_size = 0
238 self._param_slices_ = []
239 for i, p in enumerate(self.parameters):
240 if not p.param_array.flags['C_CONTIGUOUS']:# getattr(p, 'shape', None) != getattr(p, '_realshape_', None):
241 raise ValueError("""
242 Have you added an additional dimension to a Param object?
243
244 p[:,None], where p is of type Param does not work
245 and is expected to fail! Try increasing the
246 dimensionality of the param array before making
247 a Param out of it:
248 p = Param("<name>", array[:,None])
249
250 Otherwise this should not happen!
251 Please write an email to the developers with the code,
252 which reproduces this error.
253 All parameter arrays must be C_CONTIGUOUS
254 """)
255
256 p._parent_ = self
257 p._parent_index_ = i
258
259 pslice = slice(old_size, old_size + p.size)
260
261 # first connect all children
262 p._propagate_param_grad(self.param_array[pslice], self.gradient_full[pslice])
263
264 # then connect children to self
265 self.param_array[pslice] = p.param_array.flat # , requirements=['C', 'W']).ravel(order='C')
266 self.gradient_full[pslice] = p.gradient_full.flat # , requirements=['C', 'W']).ravel(order='C')
267
268 p.param_array.data = self.param_array[pslice].data
269 p.gradient_full.data = self.gradient_full[pslice].data
270
271 self._param_slices_.append(pslice)
272
273 self._add_parameter_name(p)
274 old_size += p.size
275
276 #===========================================================================
277 # Get/set parameters:
278 #===========================================================================
280 """
281 create a list of parameters, matching regular expression regexp
282 """
283 if not isinstance(regexp, _pattern_type): regexp = compile(regexp)
284 found_params = []
285 def visit(innerself, regexp):
286 if (innerself is not self) and regexp.match(innerself.hierarchy_name().partition('.')[2]):
287 found_params.append(innerself)
288 self.traverse(visit, regexp)
289 return found_params
290
292 if isinstance(name, (int, slice, tuple, np.ndarray)):
293 return self.param_array[name]
294 else:
295 paramlist = self.grep_param_names(name)
296 if len(paramlist) < 1: raise AttributeError(name)
297 if len(paramlist) == 1:
298 #if isinstance(paramlist[-1], Parameterized) and paramlist[-1].size > 0:
299 # paramlist = paramlist[-1].flattened_parameters
300 # if len(paramlist) != 1:
301 # return ParamConcatenation(paramlist)
302 return paramlist[-1]
303
304 from .param import ParamConcatenation
305 return ParamConcatenation(paramlist)
306
308 if not self._model_initialized_:
309 raise AttributeError("""Model is not initialized, this change will only be reflected after initialization if in leaf.
310
311 If you are loading a model, set updates off, then initialize, then set the values, then update the model to be fully initialized:
312 >>> m.update_model(False)
313 >>> m.initialize_parameter()
314 >>> m[:] = loaded_parameters
315 >>> m.update_model(True)
316 """)
317 if value is None:
318 return # nothing to do here
319 if isinstance(name, (slice, tuple, np.ndarray)):
320 try:
321 self.param_array[name] = value
322 except:
323 raise ValueError("Setting by slice or index only allowed with array-like")
324 self.trigger_update()
325 else:
326 param = self.__getitem__(name, paramlist)
327 param[:] = value
328
330 # override the default behaviour, if setting a param, so broadcasting can by used
331 if hasattr(self, "parameters"):
332 pnames = self.parameter_names(False, adjust_for_printing=True, recursive=False)
333 if name in pnames:
334 param = self.parameters[pnames.index(name)]
335 param[:] = val; return
336 return object.__setattr__(self, name, val);
337
338 #===========================================================================
339 # Pickling
340 #===========================================================================
342 super(Parameterized, self).__setstate__(state)
343 self._connect_parameters()
344 self._connect_fixes()
345 self._notify_parent_change()
346 self.parameters_changed()
347 return self
348
350 if memo is None:
351 memo = {}
352 memo[id(self.optimizer_array)] = None # and param_array
353 memo[id(self.param_array)] = None # and param_array
354 copy = super(Parameterized, self).copy(memo)
355 copy._connect_parameters()
356 copy._connect_fixes()
357 copy._notify_parent_change()
358 return copy
359
360 #===========================================================================
361 # Printing:
362 #===========================================================================
364 return self.hierarchy_name()
365 @property
368
370 props = []
371 for p in self.parameters:
372 props.extend(p.get_property_string(propname))
373 return props
374
375 @property
378
380 """Representation of the parameters in html for notebook display."""
381 name = adjust_name_for_printing(self.name) + "."
382 names = self.parameter_names()
383 desc = self._description_str
384 iops = OrderedDict()
385 for opname in self._index_operations:
386 iop = []
387 for p in self.parameters:
388 iop.extend(p.get_property_string(opname))
389 iops[opname] = iop
390
391 format_spec = self._format_spec(name, names, desc, iops, False)
392 to_print = []
393
394 if header:
395 to_print.append("<tr><th><b>" + '</b></th><th><b>'.join(format_spec).format(name=name, desc='value', **dict((name, name) for name in iops)) + "</b></th></tr>")
396
397 format_spec = "<tr><td class=tg-left>" + format_spec[0] + '</td><td class=tg-right>' + format_spec[1] + '</td><td class=tg-center>' + '</td><td class=tg-center>'.join(format_spec[2:]) + "</td></tr>"
398 for i in range(len(names)):
399 to_print.append(format_spec.format(name=names[i], desc=desc[i], **dict((name, iops[name][i]) for name in iops)))
400
401 style = """<style type="text/css">
402 .tg {font-family:"Courier New", Courier, monospace !important;padding:2px 3px;word-break:normal;border-collapse:collapse;border-spacing:0;border-color:#DCDCDC;margin:0px auto;width:100%;}
403 .tg td{font-family:"Courier New", Courier, monospace !important;font-weight:bold;color:#444;background-color:#F7FDFA;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#DCDCDC;}
404 .tg th{font-family:"Courier New", Courier, monospace !important;font-weight:normal;color:#fff;background-color:#26ADE4;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#DCDCDC;}
405 .tg .tg-left{font-family:"Courier New", Courier, monospace !important;font-weight:normal;text-align:left;}
406 .tg .tg-center{font-family:"Courier New", Courier, monospace !important;font-weight:normal;text-align:center;}
407 .tg .tg-right{font-family:"Courier New", Courier, monospace !important;font-weight:normal;text-align:right;}
408 </style>"""
409 return style + '\n' + '<table class="tg">' + '\n'.join(to_print) + '\n</table>'
410
412 nl = max([len(str(x)) for x in names + [name]])
413 sl = max([len(str(x)) for x in desc + ["value"]])
414
415 lls = [reduce(lambda a,b: max(a, len(b)), iops[opname], len(opname)) for opname in iops]
416
417 if VT100:
418 format_spec = [" \033[1m{{name!s:<{0}}}\033[0;0m".format(nl),"{{desc!s:>{0}}}".format(sl)]
419 else:
420 format_spec = [" {{name!s:<{0}}}".format(nl),"{{desc!s:>{0}}}".format(sl)]
421
422 for opname, l in zip(iops, lls):
423 f = '{{{1}!s:^{0}}}'.format(l, opname)
424 format_spec.append(f)
425
426 return format_spec
427
429 name = adjust_name_for_printing(self.name) + "."
430 names = self.parameter_names(adjust_for_printing=True)
431 desc = self._description_str
432 iops = OrderedDict()
433 for opname in self._index_operations:
434 iops[opname] = self.get_property_string(opname)
435
436 format_spec = ' | '.join(self._format_spec(name, names, desc, iops, VT100))
437
438 to_print = []
439
440 if header:
441 to_print.append(format_spec.format(name=name, desc='value', **dict((name, name) for name in iops)))
442
443 for i in range(len(names)):
444 to_print.append(format_spec.format(name=names[i], desc=desc[i], **dict((name, iops[name][i]) for name in iops)))
445 return '\n'.join(to_print)
446
448 """
449 Build a pydot representation of this model. This needs pydot installed.
450
451 Example Usage::
452
453 np.random.seed(1000)
454 X = np.random.normal(0,1,(20,2))
455 beta = np.random.uniform(0,1,(2,1))
456 Y = X.dot(beta)
457 m = RidgeRegression(X, Y)
458 G = m.build_pydot()
459 G.write_png('example_hierarchy_layout.png')
460
461 The output looks like:
462
463 .. image:: ./example_hierarchy_layout.png
464
465 Rectangles are parameterized objects (nodes or leafs of hierarchy).
466
467 Trapezoids are param objects, which represent the arrays for parameters.
468
469 Black arrows show parameter hierarchical dependence. The arrow points
470 from parents towards children.
471
472 Orange arrows show the observer pattern. Self references (here) are
473 the references to the call to parameters changed and references upwards
474 are the references to tell the parents they need to update.
475 """
476 import pydot # @UnresolvedImport
477 iamroot = False
478 if G is None:
479 G = pydot.Dot(graph_type='digraph', bgcolor=None)
480 iamroot=True
481 node = pydot.Node(id(self), shape='box', label=self.name)#, color='white')
482
483 G.add_node(node)
484 for child in self.parameters:
485 child_node = child.build_pydot(G)
486 G.add_edge(pydot.Edge(node, child_node))#, color='white'))
487
488 for _, o, _ in self.observers:
489 label = o.name if hasattr(o, 'name') else str(o)
490 observed_node = pydot.Node(id(o), label=label)
491 if str(id(o)) not in G.obj_dict['nodes']:
492 G.add_node(observed_node)
493 edge = pydot.Edge(str(id(self)), str(id(o)), color='darkorange2', arrowhead='vee')
494 G.add_edge(edge)
495
496 if iamroot:
497 return G
498 return node
499
| Home | Trees | Indices | Help |
|---|
| Generated by Epydoc 3.0.1 on Tue Jul 4 12:00:20 2017 | http://epydoc.sourceforge.net |