Package paramz :: Package tests :: Module pickle_tests
[hide private]
[frames] | no frames]

Source Code for Module paramz.tests.pickle_tests

  1  ''' 
  2  Created on 13 Mar 2014 
  3   
  4  @author: maxz 
  5  ''' 
  6  import unittest, pickle, tempfile, os, paramz 
  7  import numpy as np 
  8  from ..core.index_operations import ParameterIndexOperations, ParameterIndexOperationsView 
  9  from ..core.observable_array import ObsAr 
 10  from paramz.transformations import Exponent, Logexp 
 11  from ..parameterized import Parameterized 
 12  from ..param import Param 
 13   
class ListDictTestCase(unittest.TestCase):
    """TestCase base adding assertions for dicts of index lists and lists of arrays."""

    def assertListDictEquals(self, d1, d2, msg=None):
        """Assert that every entry of ``d1`` matches the entry under the same key in ``d2``.

        Both values are converted to lists before comparison, so numpy
        arrays, tuples and lists compare equal when they hold the same items
        in the same order. Only the keys of ``d1`` are checked; ``d2`` is
        simply indexed with each of them and may contain additional keys.

        :param d1: reference mapping whose items are checked.
        :param d2: mapping looked up with every key of ``d1``.
        :param msg: optional message forwarded to ``assertListEqual``.
        """
        for k, v in d1.items():
            self.assertListEqual(list(v), list(d2[k]), msg)

    def assertArrayListEquals(self, l1, l2):
        """Assert that two sequences of arrays are equal pairwise.

        Both arguments may be any finite iterable (e.g. a list or a dict
        values view). They are materialised first so their lengths can be
        compared: a plain ``zip`` would silently stop at the shorter one and
        let a missing or extra array go unnoticed.

        :raises AssertionError: if the lengths differ or any pair of arrays
            differs (via ``np.testing.assert_array_equal``).
        """
        l1, l2 = list(l1), list(l2)
        self.assertEqual(len(l1), len(l2),
                         "array lists differ in length: %d != %d" % (len(l1), len(l2)))
        for a1, a2 in zip(l1, l2):
            np.testing.assert_array_equal(a1, a2)
23
class Test(ListDictTestCase):
    """Round-trip tests: ``copy()`` and pickling of paramz core objects."""

    def test_parameter_index_operations(self):
        """Copies and pickles of index operations (and a view onto them) keep their items."""
        pio = ParameterIndexOperations(dict(test1=np.array([4, 3, 1, 6, 4]), test2=np.r_[2:130]))
        piov = ParameterIndexOperationsView(pio, 20, 250)
        self.assertListDictEquals(dict(piov.items()), dict(piov.copy().items()))
        self.assertListDictEquals(dict(pio.items()), dict(pio.copy().items()))

        self.assertArrayListEquals(pio.copy().indices(), pio.indices())
        self.assertArrayListEquals(piov.copy().indices(), piov.indices())

        # Plain pickle round trip of the full index operations.
        with tempfile.TemporaryFile('w+b') as f:
            pickle.dump(pio, f)
            f.seek(0)
            pio2 = pickle.load(f)
            self.assertListDictEquals(pio._properties, pio2._properties)

        # Pickle the view, but load it back through paramz.load.
        with tempfile.TemporaryFile('w+b') as f:
            pickle.dump(piov, f)
            f.seek(0)
            pio2 = paramz.load(f)
            self.assertListDictEquals(dict(piov.items()), dict(pio2.items()))

    def test_param(self):
        """A Param with constraints survives ``copy()`` and a pickle round trip."""
        param = Param('test', np.arange(4 * 2).reshape(4, 2))
        param[0].constrain_positive()
        param[1].fix()
        pcopy = param.copy()
        self.assertListEqual(param.tolist(), pcopy.tolist())
        self.assertListEqual(str(param).split('\n'), str(pcopy).split('\n'))
        self.assertIsNot(param, pcopy)
        with tempfile.TemporaryFile('w+b') as f:
            pickle.dump(param, f)
            f.seek(0)
            pcopy = paramz.load(f)
        self.assertListEqual(param.tolist(), pcopy.tolist())
        self.assertSequenceEqual(str(param), str(pcopy))

    def test_observable_array(self):
        """An ObsAr survives ``copy()`` and a pickle round trip through a named file."""
        obs = ObsAr(np.arange(4 * 2).reshape(4, 2))
        pcopy = obs.copy()
        self.assertListEqual(obs.tolist(), pcopy.tolist())
        # obs.pickle()/paramz.load() are exercised with a file *name* here, so a
        # real path is needed. mkstemp creates a uniquely named file in the
        # temp directory (not the CWD). Because the file already exists, the
        # cleanup in ``finally`` cannot fail and mask an error from pickle/load.
        fd, tmpfile = tempfile.mkstemp()
        os.close(fd)
        try:
            obs.pickle(tmpfile)
            pcopy = paramz.load(tmpfile)
        finally:
            os.remove(tmpfile)
        self.assertListEqual(obs.tolist(), pcopy.tolist())
        self.assertSequenceEqual(str(obs), str(pcopy))

    def test_parameterized(self):
        """A nested Parameterized keeps values, constraints and structure on copy/pickle."""
        par = Parameterized('parameterized')
        p2 = Parameterized('rbf')
        p2.p1 = Param('lengthscale', np.random.uniform(0.1, .5, 3), Exponent())
        p2.link_parameter(p2.p1)
        par.p1 = p2
        par.p2 = Param('linear', np.random.uniform(0.1, .5, 2), Logexp())
        par.link_parameters(par.p1, par.p2)

        par.gradient = 10
        par.randomize()
        pcopy = par.copy()
        # The children's constraint views must point at the copy's own
        # top-level index operations, not at the original's.
        self.assertIsInstance(pcopy.constraints, ParameterIndexOperations)
        self.assertIsInstance(pcopy.rbf.constraints, ParameterIndexOperationsView)
        self.assertIs(pcopy.constraints, pcopy.rbf.constraints._param_index_ops)
        self.assertIs(pcopy.constraints, pcopy.rbf.lengthscale.constraints._param_index_ops)
        self.assertListEqual(par.param_array.tolist(), pcopy.param_array.tolist())
        pcopy.gradient = 10  # gradient does not get copied anymore
        self.assertListEqual(par.gradient_full.tolist(), pcopy.gradient_full.tolist())
        self.assertSequenceEqual(str(par), str(pcopy))
        self.assertIsNot(par.param_array, pcopy.param_array)
        self.assertIsNot(par.gradient_full, pcopy.gradient_full)
        with tempfile.TemporaryFile('w+b') as f:
            par.pickle(f)
            f.seek(0)
            pcopy = paramz.load(f)
        self.assertListEqual(par.param_array.tolist(), pcopy.param_array.tolist())
        pcopy.gradient = 10
        np.testing.assert_allclose(par.linear.gradient_full, pcopy.linear.gradient_full)
        np.testing.assert_allclose(pcopy.linear.gradient_full, 10)
        self.assertSequenceEqual(str(par), str(pcopy))

    def _callback(self, what, which):
        # Increments a ``count`` attribute on ``what``; ``which`` is ignored.
        # Not referenced by any test in this module. Presumably kept as an
        # observer-callback helper; TODO confirm or remove.
        what.count += 1
if __name__ == "__main__":
    # To run a single test instead of the whole module, use e.g.:
    #   python -m unittest paramz.tests.pickle_tests.Test.test_parameter_index_operations
    unittest.main()