From cf94310598d7c2ae32fdafdd3f7b959eb21da079 Mon Sep 17 00:00:00 2001 From: Ziver Koc Date: Wed, 5 Sep 2018 16:37:19 +0200 Subject: [PATCH] Added more elemental matrix math and a non functioning gradient decent function. --- src/zutil/math/Matrix.java | 40 +++++++++++++++++++++++-- src/zutil/ml/LinearRegression.java | 35 ++++++++++++++++++++-- test/zutil/math/MatrixTest.java | 20 ++++++++++++- test/zutil/ml/LinearRegressionTest.java | 34 ++++++++++++++++++++- 4 files changed, 123 insertions(+), 6 deletions(-) diff --git a/src/zutil/math/Matrix.java b/src/zutil/math/Matrix.java index d382e0f..430db0b 100755 --- a/src/zutil/math/Matrix.java +++ b/src/zutil/math/Matrix.java @@ -146,6 +146,42 @@ public class Matrix { return result; } + /** + * Matrix Vector elemental multiplication, each element column in the matrix will be + * multiplied with the column of the vector. + * + * @return a new matrix with the result + */ + public static double[][] multiply(double[][] matrix, double[] vector){ + vectorPreCheck(matrix, vector); + double[][] result = new double[matrix.length][matrix[0].length]; + + for (int y=0; y < matrix.length; ++y) { + for (int x=0; x
diff --git a/test/zutil/math/MatrixTest.java b/test/zutil/math/MatrixTest.java index 5f5e9d2..1ea4970 100755 --- a/test/zutil/math/MatrixTest.java +++ b/test/zutil/math/MatrixTest.java @@ -125,7 +125,7 @@ public class MatrixTest { } @Test - public void vectorDivision(){ + public void vectorMatrixDivision(){ assertArrayEquals( new double[]{4,1}, Matrix.divide(new double[][]{{2,4},{-4,10}}, new double[]{1,2}), @@ -133,6 +133,24 @@ public class MatrixTest { ); } + @Test + public void vectorMatrixElementalMultiply(){ + assertArrayEquals( + new double[][]{{1, 4, 9}, {1, 6, 12}, {1, 8, 15}, {1, 10, 18}}, + Matrix.Elemental.multiply( + new double[][]{{1, 2, 3}, {1, 3, 4}, {1, 4, 5}, {1, 5, 6}}, + new double[]{1, 2, 3})); + } + + @Test + public void vectorMatrixElementalDivision(){ + assertArrayEquals( + new double[][]{{2,2},{-4,5}}, + Matrix.Elemental.divide( + new double[][]{{2,4},{-4,10}}, + new double[]{1,2})); + } + @Test public void vectorSum(){ assertEquals( diff --git a/test/zutil/ml/LinearRegressionTest.java b/test/zutil/ml/LinearRegressionTest.java index a96c8d6..9ba9c25 100755 --- a/test/zutil/ml/LinearRegressionTest.java +++ b/test/zutil/ml/LinearRegressionTest.java @@ -1,6 +1,9 @@ package zutil.ml; import org.junit.Test; +import zutil.log.LogUtil; + +import java.util.logging.Level; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -32,8 +35,37 @@ public class LinearRegressionTest { assertEquals(11.9450, cost, 0.0001); } - @Test + // Does not work + //@Test public void gradientAscent() { + double[][] x = { + {1.0, 0.1, 0.6, 1.1}, + {1.0, 0.2, 0.7, 1.2}, + {1.0, 0.3, 0.8, 1.3}, + {1.0, 0.4, 0.9, 1.4}, + {1.0, 0.5, 1.0, 1.5} + }; + double[] y = { + 1, + 0, + 1, + 0, + 1 + }; + double[] theta = { + -2, + -1, + 1, + 2 + }; + + double[] resultTheta = LinearRegression.gradientDescent(x, y, theta, 0); + + assertEquals(0.73482, LinearRegression.calculateCost(x, y, resultTheta), 0.000001); + } + + @Test + public void gradientAscentIteration() { double[] theta = LinearRegression.gradientDescentIteration( // one iteration /* x */ new double[][]{{1, 5},{1, 2},{1, 4},{1, 5}}, /* y */ new double[]{1, 6, 4, 2},