1 '''
2 Created on 4 Sep 2015
3
4 @author: maxz
5 '''
6 import unittest
7 from ..caching import Cacher
8 from pickle import PickleError
9 from ..core.observable_array import ObsAr
10 import numpy as np
11 from paramz.caching import Cache_this
class TestDecorator(unittest.TestCase):
    """Tests for the ``Cache_this`` decorator and for ``Cacher`` on bound methods."""

    def setUp(self):
        # Call log written by the decorated method:
        #   opcalls[0] -- how many times the wrapped body actually ran,
        #   opcalls[1] -- the last ``ignore_this`` value it received,
        #   opcalls[2] -- the last ``force`` value that was not False.
        opcalls = [0, 0, 0]

        class O(object):
            # The tests below check that ``ignore_this`` does not affect the
            # cache key and that passing ``force`` bypasses the cache.
            @Cache_this(ignore_args=[0, 3], force_kwargs=('force',))
            def __call__(self, x, y, ignore_this, force=False):
                """Documentation"""
                opcalls[0] += 1
                opcalls[1] = ignore_this
                if force is not False:
                    opcalls[2] = force
                return x + y

        self.opcalls = opcalls
        self.O = O
        self.cached = self.O()

    def test_cacher_cache(self):
        """Cachers built from bound methods are reachable via ``obj.cache``."""
        class O(object):
            def __init__(self):
                self._test = Cacher(self.test)
                self._test_other = Cacher(self.other_method)

            def test(self, x, y):
                return x * y

            def other_method(self, x, y):
                return x + y

        a = ObsAr(np.random.normal(0, 1, (2, 1)))
        b = ObsAr(np.random.normal(0, 1, (2, 1)))

        o = O()

        # Repeated call with unchanged inputs returns the very same object.
        ab = o._test(a, b)
        self.assertIs(ab, o._test(a, b))
        aba = o._test(ab, a)
        self.assertIsNot(ab, aba)

        # The (a, b) result survived the intervening (ab, a) call.
        self.assertIs(ab, o._test(a, b))

        # Mutating an input invalidates every entry that depends on it.
        a[0] = 15
        self.assertIsNot(ab, o._test(a, b))
        self.assertIsNot(ab, o._test(ab, a))

        # O never defines ``cache`` itself; the Cachers registered there,
        # keyed by the bound method they wrap.
        self.assertIs(o.cache[o.test], o._test)
        self.assertIs(o.cache[o.other_method], o._test_other)

    def test_cached_ignore_args_and_force_kwargs(self):
        """Ignored args do not trigger recomputation; force kwargs always do."""
        a = ObsAr(np.random.normal(0, 1, (2, 1)))
        b = ObsAr(np.random.normal(0, 1, (2, 1)))

        ab = self.cached(a, b, 'ignored')
        self.assertEqual(self.opcalls[0], 1)
        self.assertEqual(self.opcalls[1], 'ignored')

        # A different ``ignore_this`` hits the cache: body not re-run.
        self.assertIs(ab, self.cached(a, b, 2))
        self.assertEqual(self.opcalls[1], 'ignored')
        self.assertEqual(self.opcalls[0], 1)

        # Supplying ``force`` re-runs the body and yields a new (equal) object.
        abnew = self.cached(a, b, 3, force='given')
        self.assertEqual(self.opcalls[0], 2)
        self.assertEqual(self.opcalls[1], 3)
        self.assertEqual(self.opcalls[2], 'given')
        self.assertIsNot(ab, abnew)
        np.testing.assert_array_equal(abnew, ab)

        abforced = self.cached(a, b, 4, force='given2')
        self.assertEqual(self.opcalls[0], 3)
        self.assertEqual(self.opcalls[1], 4)
        self.assertEqual(self.opcalls[2], 'given2')
        np.testing.assert_array_equal(abnew, abforced)

    def test_signature(self):
        """The decorator preserves the wrapped method's signature and docstring."""
        try:
            from inspect import getfullargspec as getspec
        except ImportError:  # Python 2 only has getargspec
            from inspect import getargspec as getspec
        from inspect import getdoc
        self.assertEqual(getspec(self.cached.__call__), getspec(self.O.__call__))
        self.assertEqual(getdoc(self.cached.__call__), getdoc(self.O.__call__))

    def test_offswitch(self):
        """Caching can be switched off and back on."""
        a = ObsAr(np.random.normal(0, 1, (2, 1)))
        b = ObsAr(np.random.normal(0, 1, (2, 1)))
        ab1 = self.cached(a, b, 'ignored')
        ab2 = self.cached(a, b, 2)
        self.assertIs(ab1, ab2)

        # While disabled, every call recomputes.
        self.cached.cache.disable_caching()
        ab3 = self.cached(a, b, 'ignored')
        self.assertIsNot(ab1, ab3)

        # After re-enabling, earlier results are not handed back, but the
        # fresh result is cached again.
        self.cached.cache.enable_caching()
        ab4 = self.cached(a, b, 'ignored')
        self.assertIsNot(ab1, ab4)
        self.assertIsNot(ab3, ab4)
        self.assertIs(ab4, self.cached(a, b, 'its ignored'))

    def test_reset(self):
        """``cache.reset()`` discards previously cached results."""
        a = ObsAr(np.random.normal(0, 1, (2, 1)))
        b = ObsAr(np.random.normal(0, 1, (2, 1)))
        ab1 = self.cached(a, b, 'ignored')
        ab2 = self.cached(a, b, 2)
        self.assertIs(ab1, ab2)
        self.cached.cache.reset()
        self.assertIsNot(ab1, self.cached(a, b, 2))
121
122
class Test(unittest.TestCase):
    """Tests for ``Cacher`` wrapping a plain function."""

    def setUp(self):
        def op(x, *args):
            return (x,) + args
        # ``op`` echoes its arguments back as a tuple. The second Cacher
        # argument appears to be the entry limit: the tests below never
        # observe more than two cached entries.
        self.cached = Cacher(op, 2)

    def test_pickling(self):
        self.assertRaises(PickleError, self.cached.__getstate__)
        self.assertRaises(PickleError, self.cached.__setstate__)

    def test_copy(self):
        tmp = self.cached.__deepcopy__()
        # Use the TestCase assertion; a bare ``assert`` vanishes under -O.
        self.assertIs(tmp.operation, self.cached.operation)
        self.assertEqual(tmp.limit, self.cached.limit)

    def test_reset(self):
        """A failing operation propagates its error and resets the cacher."""
        opcalls = [0]

        def op(x, y):
            opcalls[0] += 1
            return x + y

        cache = Cacher(op, 1)

        ins = [1, 2]
        b = cache(*ins)
        self.assertIs(cache(*ins), b)
        self.assertEqual(opcalls[0], 1)
        self.assertIn(ins, cache.cached_inputs.values())

        # str + int raises inside ``op``; the TypeError must reach the caller ...
        self.assertRaises(TypeError, cache, 'this does not work', 2)

        # ... and the cacher must be left completely reset.
        self.assertDictEqual(cache.cached_input_ids, {})
        self.assertDictEqual(cache.cached_outputs, {})
        self.assertDictEqual(cache.inputs_changed, {})

    def test_cached_atomic_str(self):
        i = "printing the cached value"
        self.cached(i)
        self.assertIn((i,), self.cached.cached_outputs.values())
        self.cached(i)
        self.assertIn((i,), self.cached.cached_outputs.values())
        # The repeated call reused the entry instead of adding a second one.
        self.assertEqual(len(self.cached.cached_outputs.values()), 1)

    def test_caching_non_cachables(self):
        class O(object):
            "not cachable"
        o = O()
        c = self.cached(o, 1)

        # The operation still runs and returns its arguments ...
        self.assertIs(c[0], o)
        # ... but nothing is stored for a non-cachable argument.
        self.assertEqual(len(self.cached.cached_outputs.values()), 0)

    def test_force_kwargs(self):
        opcalls = [0]

        def op(x, y, force=False):
            opcalls[0] += 1
            return x + y

        cache = Cacher(op, 1, force_kwargs=('force',))

        a = ObsAr(np.random.normal(0, 1, (2, 1)))
        b = ObsAr(np.random.normal(0, 1, (2, 1)))

        ab = cache(a, b)
        self.assertEqual(opcalls[0], 1)
        self.assertIs(ab, cache(a, b))
        # Passing a force kwarg bypasses the cached result.
        self.assertIsNot(ab, cache(a, b, force='given'))
        self.assertEqual(opcalls[0], 2)

    def test_reset_on_operation_error(self):
        opcalls = [0]
        a = ObsAr(np.random.randn())
        b = ObsAr(np.random.randn())

        def op(x, y, force=False):
            opcalls[0] += 1
            # Fail on exactly the second real invocation.
            if opcalls[0] == 2:
                raise RuntimeError("Oh Noooo, something went wrong here, what is the cacher to do????")
            return x + y

        cache = Cacher(op, 1, force_kwargs=('force',))
        ab = cache(a, b)
        self.assertEqual(ab, a + b)
        self.assertIs(ab, cache(a, b))
        self.assertEqual(opcalls[0], 1)
        # (b, a) is a new key, so op runs a second time and raises.
        self.assertRaises(RuntimeError, cache, b, a)
        self.assertEqual(len(cache.cached_inputs), 0)

        # After the reset the cacher works normally again.
        ab = cache(a, b)
        self.assertIs(ab, cache(a, b))

    def test_cached_atomic_int(self):
        i = 1234
        self.cached(i)
        self.assertIn((i,), self.cached.cached_outputs.values())
        self.cached(i)
        self.assertIn((i,), self.cached.cached_outputs.values())
        self.assertEqual(len(self.cached.cached_outputs.values()), 1)

    def test_cached_ObsAr(self):
        i = ObsAr(np.random.normal(0, 1, (10, 3)))

        self.cached(i)
        inputs = self.cached.combine_inputs((i,), {}, ())
        id_ = self.cached.prepare_cache_id(inputs)

        self.assertIs(i, self.cached.cached_inputs[id_][0])
        self.assertIs(i, self.cached(i)[0])

        # Mutating the observable input flags its entry as changed ...
        i[0] = 10
        self.assertTrue(self.cached.inputs_changed[id_])

        # ... and the next call recomputes and clears the flag.
        self.cached(i)
        self.assertFalse(self.cached.inputs_changed[id_])

    def test_cached_ObsAr_with_atomic_args(self):
        i = ObsAr(np.random.normal(0, 1, (10, 3)))
        self.cached(i, 1234)

        self.cached(i, 1234)
        inputs = self.cached.combine_inputs((i, 1234), {}, ())
        id_ = self.cached.prepare_cache_id(inputs)

        self.assertIs(i, self.cached.cached_inputs[id_][0])
        self.assertIs(i, self.cached(i, 1234)[0])
        self.assertIs(1234, self.cached(i, 1234)[1])

        # Mutation marks (i, 1234) as changed; recomputing clears the flag.
        i[0] = 10
        self.assertTrue(self.cached.inputs_changed[id_])
        old_c = self.cached(i, 1234)
        self.assertFalse(self.cached.inputs_changed[id_])

        # A different atomic argument creates a second entry.
        old_c1235 = self.cached(i, 1235)
        self.assertIs(old_c1235, self.cached(i, 1235))
        self.assertEqual(len(self.cached.cached_inputs), 2)

        # A third key still leaves only two entries, so one was evicted.
        another = self.cached(i, "another")
        self.assertIs(self.cached(i, "another"), another)
        self.assertEqual(len(self.cached.cached_inputs), 2)

        # The (i, 1234) entry is expected to be the evicted one: it is
        # recomputed into a new object.
        self.assertIsNot(old_c, self.cached(i, 1234))
        self.assertEqual(len(self.cached.cached_inputs), 2)

        self.assertFalse(self.cached.inputs_changed[id_])
        i[4] = 3
        self.assertTrue(self.cached.inputs_changed[id_])

        self.assertIsNot(old_c1235, self.cached(i, 1235))

        self.assertIsNot(self.cached(i, "another"), another)

    def test_sum_ObsAr(self):
        a = ObsAr(np.random.normal(0, 1, (2, 1)))
        b = ObsAr(np.random.normal(0, 1, (2, 1)))
        c = ObsAr(np.random.normal(0, 1, (2, 1)))

        def op(x, y):
            return x + y

        cache = Cacher(op, 2)

        # Cache ids for every input pair used below.
        id_ab = cache.prepare_cache_id(cache.combine_inputs((a, b), {}, ()))
        id_ac = cache.prepare_cache_id(cache.combine_inputs((a, c), {}, ()))
        id_bc = cache.prepare_cache_id(cache.combine_inputs((b, c), {}, ()))

        ab = cache(a, b)
        np.testing.assert_array_equal(ab, a + b)
        self.assertIs(cache(a, b), ab)
        self.assertFalse(cache.inputs_changed[id_ab])

        # An in-place update of an input invalidates the cached sum.
        a[:] += np.random.normal(0, 1, (2, 1))
        self.assertTrue(cache.inputs_changed[id_ab])
        abnew = cache(a, b)

        id_ab = cache.prepare_cache_id(cache.combine_inputs((a, b), {}, ()))
        self.assertFalse(cache.inputs_changed[id_ab])
        self.assertIsNot(abnew, ab)
        self.assertIs(cache(a, b), abnew)
        np.testing.assert_array_equal(cache(a, b), a + b)

        np.testing.assert_array_equal(cache(a, c), a + c)
        np.testing.assert_array_equal(cache(b, c), b + c)

        # Only two entries fit, so (a, b) was pushed out and (a, c) has now
        # been dropped in turn.
        ab = cache(a, b)
        self.assertIsNot(abnew, ab)
        self.assertIs(ab, cache(a, b))
        self.assertIn(id_bc, cache.inputs_changed)
        self.assertIn(id_ab, cache.inputs_changed)
        self.assertNotIn(id_ac, cache.inputs_changed)
333
336
if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main()
340