Package paramz :: Package tests :: Module examples_tests
[hide private]
[frames | no frames]

Source Code for Module paramz.tests.examples_tests

 1  #=============================================================================== 
 2  # Copyright (c) 2015, Max Zwiessele 
 3  # All rights reserved. 
 4  # 
 5  # Redistribution and use in source and binary forms, with or without 
 6  # modification, are permitted provided that the following conditions are met: 
 7  # 
 8  # * Redistributions of source code must retain the above copyright notice, this 
 9  #   list of conditions and the following disclaimer. 
10  # 
11  # * Redistributions in binary form must reproduce the above copyright notice, 
12  #   this list of conditions and the following disclaimer in the documentation 
13  #   and/or other materials provided with the distribution. 
14  # 
15  # * Neither the name of paramax nor the names of its 
16  #   contributors may be used to endorse or promote products derived from 
17  #   this software without specific prior written permission. 
18  # 
19  # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
20  # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
21  # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 
22  # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 
23  # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 
24  # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
25  # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 
26  # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 
27  # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
28  # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29  #=============================================================================== 
30   
31  import unittest 
32  import numpy as np 
33  from ..examples import RidgeRegression, Lasso, Polynomial 
34   
class Test2D(unittest.TestCase):
    """Checks the example regression models on noise-free, two-column linear data."""

    def testRidgeRegression(self):
        """Fit ``RidgeRegression`` with its defaults; verify weights, basis output and predictions."""
        almost_equal = np.testing.assert_array_almost_equal

        # Fixed seed; the draw order (inputs first, then coefficients) is kept
        # so the generated data stays exactly the same.
        np.random.seed(1000)
        inputs = np.random.normal(0, 1, (20, 2))
        true_beta = np.random.uniform(0, 1, (2, 1))
        # Targets are noise-free: a disabled line used to add
        # np.random.normal(0, .001, <target shape>) to them.
        targets = inputs.dot(true_beta)

        model = RidgeRegression(inputs, targets)
        model.regularizer.lambda_ = 0.00001
        self.assertTrue(model.checkgrad())

        # Two optimisation passes, each capped at 10 iterations: first 'scg'
        # with every tolerance set to zero, then the default optimizer.
        model.optimize('scg', gtol=0, ftol=0, xtol=0, max_iters=10)
        model.optimize(max_iters=10)

        # weights[1] must match the generating coefficients and weights[0]
        # must be ~0 (presumably the constant-term weights -- confirm against
        # the basis implementation); the gradient must have vanished.
        almost_equal(model.regularizer.weights[1], true_beta[:, 0], decimal=4)
        almost_equal(model.regularizer.weights[0], [0, 0], decimal=4)
        almost_equal(model.gradient, np.zeros(model.weights.size), decimal=4)

        # Prediction grid: both columns start as linspace(0, 1, 50), then
        # column 1 is overwritten with its own reversal.
        column = np.linspace(0, 1, 50)[:, None]
        grid = np.repeat(column, 2, axis=1)
        grid[:, 1] = grid[::-1, 1]

        basis = model.phi(grid)
        almost_equal(basis[0], np.zeros_like(grid), decimal=4)
        almost_equal(basis[1], grid * true_beta.T)
        # Asking phi for a single entry via its second argument must give
        # exactly the matching entry of the full result.
        for index in (0, 1):
            single = model.phi(grid, [index])
            np.testing.assert_array_equal(basis[index], single[0])

        predicted = model.predict(grid)
        almost_equal(predicted, grid.dot(true_beta))

    def testLassoRegression(self):
        """Fit ``RidgeRegression`` with a ``Lasso`` regularizer and a ``Polynomial(1)`` basis."""
        almost_equal = np.testing.assert_array_almost_equal

        # Fixed seed; inputs are drawn before the coefficients.
        np.random.seed(12345)
        inputs = np.random.uniform(0, 1, (20, 2))
        true_beta = np.random.normal(0, 1, (2, 1))
        # Targets are noise-free: a disabled line used to add
        # np.random.normal(0, .001, <target shape>) to them.
        targets = inputs.dot(true_beta)

        model = RidgeRegression(
            inputs,
            targets,
            regularizer=Lasso(.00001),
            basis=Polynomial(1),
        )
        self.assertTrue(model.checkgrad())
        model.optimize(max_iters=10)

        # Same expectations as the ridge test, but only to 3 decimals.
        almost_equal(model.regularizer.weights[1], true_beta[:, 0], decimal=3)
        almost_equal(model.regularizer.weights[0], [0, 0], decimal=3)
        almost_equal(model.gradient, np.zeros(model.weights.size), decimal=3)

# Disabled variant of the Lasso check that passes a vector of ones as the
# second argument to Lasso; kept for reference only:
#     m = RidgeRegression(X, Y, regularizer=Lasso(.00001, np.ones(X.shape[1])))
#     self.assertTrue(m.checkgrad())
#     m.optimize()
#     np.testing.assert_array_almost_equal(m.regularizer.weights[1], beta, 4)
#     np.testing.assert_array_almost_equal(m.gradient, np.zeros(m.weights.shape[0]), 4)


if __name__ == "__main__":
    # Script entry point: run every test in this module.
    # To run a single test instead, set sys.argv before the call, e.g.
    #     import sys; sys.argv = ['', 'Test2D.testRidgeRegression']
    unittest.main()