Package paramz :: Package tests :: Module observable_tests
[hide private]
[frames] | [no frames]

Source Code for Module paramz.tests.observable_tests

  1  # Copyright (c) 2014, Max Zwiessele 
  2  # Licensed under the BSD 3-clause license (see LICENSE.txt) 
  3  import unittest 
  4  from ..core.observable_array import ObsAr 
  5  from ..parameterized import Parameterized 
  6  from ..param import Param 
  7  import numpy as np 
  8   
  9  # One trigger in init 
 10  _trigger_start = -1 
 11   
class ParamTestParent(Parameterized):
    """Parent container that counts its ``parameters_changed`` calls."""

    # Class-level start value; the first increment below creates the
    # per-instance counter.
    parent_changed_count = _trigger_start

    def parameters_changed(self):
        """Record one more ``parameters_changed`` notification."""
        self.parent_changed_count = self.parent_changed_count + 1


class ParameterizedTest(Parameterized):
    """Child model that counts its ``parameters_changed`` calls."""

    # One trigger happens after initialization, hence the -1-based start.
    params_changed_count = _trigger_start

    def parameters_changed(self):
        """Record one more ``parameters_changed`` notification."""
        self.params_changed_count = self.params_changed_count + 1


class TestMisc(unittest.TestCase):
    """Checks how ``ObsAr`` treats the dtype of the array it wraps."""

    def test_casting(self):
        """Integer input is cast to float; float and string input keep their dtype."""
        ints = np.array(range(10))
        self.assertEqual(ints.dtype, np.int_)
        floats = np.arange(0, 5, .5)
        # ``np.float_`` was removed in NumPy 2.0; ``np.float64`` is the same
        # type and exists in every NumPy version.
        self.assertEqual(floats.dtype, np.float64)
        strings = np.array(list('testing'))
        self.assertEqual(strings.dtype.type, np.str_)

        self.assertEqual(ObsAr(ints).dtype, np.float64)
        self.assertEqual(ObsAr(floats).dtype, np.float64)
        self.assertEqual(ObsAr(strings).dtype.type, np.str_)


class Test(unittest.TestCase):
    """Observer / parameters_changed behaviour of a small parameter hierarchy."""

    def setUp(self):
        """Build ``parent`` holding ``par`` (param ``p`` plus two scalars) and ``par2``."""
        self.parent = ParamTestParent('test parent')
        self.par = ParameterizedTest('test model')
        self.par2 = ParameterizedTest('test model 2')
        self.p = Param('test parameter', np.random.normal(1, 2, (10, 3)))

        self.par.link_parameter(self.p)
        self.par.link_parameter(Param('test1', np.random.normal(0, 1, (1,))))
        self.par.link_parameter(Param('test2', np.random.normal(0, 1, (1,))))

        self.par2.link_parameter(Param('par2 test1', np.random.normal(0, 1, (1,))))
        self.par2.link_parameter(Param('par2 test2', np.random.normal(0, 1, (1,))))

        self.parent.link_parameter(self.par)
        self.parent.link_parameter(self.par2)

        # State written by the observer callbacks below.
        self._observer_triggered = None
        self._trigger_count = 0
        self._first = None
        self._second = None

    def _record_call_order(self, callback):
        """Store *callback* in ``_first`` on the first call, else in ``_second``."""
        if self._first is not None:
            self._second = callback
        else:
            self._first = callback

    def _trigger(self, me, which):
        """Observer callback: remember ``which``, count the call and record its order."""
        self._observer_triggered = which
        self._trigger_count += 1
        self._record_call_order(self._trigger)

    def _trigger_priority(self, me, which):
        """Observer callback that only records its position in the call order."""
        self._record_call_order(self._trigger_priority)

    def test_observable(self):
        """Adding/removing an observer controls whether it fires; the
        parameters_changed counts of ``par`` and ``parent`` keep advancing together."""
        self.par.add_observer(self, self._trigger, -1)
        self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet')
        self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')

        self.p[0, 1] = 3  # trigger observers
        self.assertIs(self._observer_triggered, self.p, 'observer should have triggered')
        self.assertEqual(self._trigger_count, 1, 'observer should have triggered once')
        self.assertEqual(self.par.params_changed_count, 1, 'params changed once')
        self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')

        # After removal the callback must not run: its count stays at 1.
        self.par.remove_observer(self)
        self.p[0, 1] = 4
        self.assertIs(self._observer_triggered, self.p, 'observer should not have triggered')
        self.assertEqual(self._trigger_count, 1, 'observer should have triggered once')
        self.assertEqual(self.par.params_changed_count, 2, 'params changed twice')
        self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')

        self.par.add_observer(self, self._trigger, -1)
        self.p[0, 1] = 4
        self.assertIs(self._observer_triggered, self.p, 'observer should have triggered')
        self.assertEqual(self._trigger_count, 2, 'observer should have triggered twice')
        self.assertEqual(self.par.params_changed_count, 3, 'params changed three times')
        self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')

        # Removing by (observer, callable) must also stop the callback.
        self.par.remove_observer(self, self._trigger)
        self.p[0, 1] = 3
        self.assertIs(self._observer_triggered, self.p, 'observer should not have triggered')
        self.assertEqual(self._trigger_count, 2, 'observer should still have triggered only twice')
        self.assertEqual(self.par.params_changed_count, 4, 'params changed four times')
        self.assertEqual(self.par.params_changed_count, self.parent.parent_changed_count, 'parent should be triggered as often as param')

    def test_set_params(self):
        """Writing ``param_array`` then calling ``_trigger_params_changed`` bumps
        the counters, except after ``set_updates(False)``."""
        self.assertEqual(self.par.params_changed_count, 0, 'no params changed yet')
        self.par.param_array[:] = 1
        self.par._trigger_params_changed()
        self.assertEqual(self.par.params_changed_count, 1, 'now params changed')
        self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)

        self.par.param_array[:] = 2
        self.par._trigger_params_changed()
        self.assertEqual(self.par.params_changed_count, 2, 'now params changed')
        self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)

        self.par.set_updates(False)
        self.par.param_array[:] = 3
        self.par._trigger_params_changed()
        self.assertEqual(self.par.params_changed_count, 2, 'should not have changed, as updates are off')
        self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)

    def test_priority_notify(self):
        """``notify_observers(0, None)`` reaches the parent;
        ``notify_observers(0, -np.inf)`` only updates ``par``."""
        self.assertEqual(self.par.params_changed_count, 0)
        self.par.notify_observers(0, None)
        self.assertEqual(self.par.params_changed_count, 1)
        self.assertEqual(self.parent.parent_changed_count, self.par.params_changed_count)

        self.par.notify_observers(0, -np.inf)
        self.assertEqual(self.par.params_changed_count, 2)
        # The parent count must stay at its previous value of 1.
        self.assertEqual(self.parent.parent_changed_count, 1)

    def test_priority(self):
        """Observers with a higher priority are called first, also after
        ``change_priority``."""
        self.par.add_observer(self, self._trigger, -1)
        self.par.add_observer(self, self._trigger_priority, 0)
        self.par.notify_observers(0)
        self.assertEqual(self._first, self._trigger_priority, 'priority should be first')
        self.assertEqual(self._second, self._trigger, 'trigger should be second')

        self.par.remove_observer(self)
        self._first = self._second = None

        self.par.add_observer(self, self._trigger, 1)
        self.par.add_observer(self, self._trigger_priority, 0)
        self.par.notify_observers(0)
        self.assertEqual(self._first, self._trigger, 'trigger should be first')
        self.assertEqual(self._second, self._trigger_priority, 'priority should be second')

        self._first = self._second = None
        self.par.change_priority(self, self._trigger, -1)
        self.par.change_priority(self, self._trigger_priority, 0)
        self.par.notify_observers(0)
        self.assertEqual(self._first, self._trigger_priority, 'priority should be first')
        self.assertEqual(self._second, self._trigger, 'trigger should be second')

    def testObsAr(self):
        """Slice assignment on an ``ObsAr`` is visible when reading the slice back."""
        o = ObsAr(np.random.normal(0, 1, (10)))
        o[3:5] = 5
        np.testing.assert_array_equal(o[3:5].values, 5)


if __name__ == "__main__":
    # Run all TestCases in this module; a single test can be selected on the
    # command line, e.g. ``python -m paramz.tests.observable_tests Test.test_priority``.
    unittest.main()