diff --git a/Zutil.iml b/Zutil.iml index a2bde11..5454eb9 100755 --- a/Zutil.iml +++ b/Zutil.iml @@ -45,5 +45,15 @@ + + + + + + + + + + \ No newline at end of file diff --git a/src/zutil/math/Matrix.java b/src/zutil/math/Matrix.java index 9ec3e74..6ca6d76 100755 --- a/src/zutil/math/Matrix.java +++ b/src/zutil/math/Matrix.java @@ -226,7 +226,7 @@ public class Matrix { * Matrix Vector subtraction, each column in the matrix will be subtracted * with the vector. * - * @return a new vector with subtracted elements + * @return a new matrix with subtracted elements */ public static double[][] subtract(double[][] matrix, double[] vector){ vectorPreCheck(matrix, vector); @@ -247,15 +247,13 @@ public class Matrix { * * @return a new vector with the result */ - public static double[][] multiply(double[][] matrix, double[] vector){ + public static double[] multiply(double[][] matrix, double[] vector){ vectorPreCheck(matrix, vector); - double[][] result = new double[matrix.length][1]; + double[] result = new double[matrix.length]; - for (int y=0; y < result.length; ++y) { + for (int y=0; y < matrix.length; ++y) { for (int x=0; x transpose(theta) * x * */ - protected static double[][] calculateHypotesis(double[][] x, double[] theta){ + protected static double[] calculateHypothesis(double[][] x, double[] theta){ return Matrix.multiply(x, theta); } /** - * Linear Regresion cost method. + * Linear Regression cost method. *

* * J(O) = 1 / (2 * m) * Σ { ( h(Xi) - Yi )^2 } @@ -30,10 +30,11 @@ public class LinearRegression { * @return a number indicating the error rate */ protected static double calculateCost(double[][] x, double[] y, double[] theta){ + double[] hypothesis = calculateHypothesis(x, theta); + double[] normalized = Matrix.subtract(hypothesis, y); + return 1.0 / (2.0 * x.length) * Matrix.sum( - Matrix.Elemental.pow( - Matrix.subtract(calculateHypotesis(x, theta), y), - 2)); + Matrix.Elemental.pow(normalized,2)); } /** @@ -45,13 +46,14 @@ public class LinearRegression { * * @return the theta that was found to minimize the cost function */ - public static double[] gradientAscent(double[][] x, double[] y, double[] theta, double alpha){ + public static double[] gradientDescent(double[][] x, double[] y, double[] theta, double alpha){ double[] newTheta = new double[theta.length]; double m = y.length; - double[][] hypotesisCache = Matrix.subtract(calculateHypotesis(x, theta), y); + double[] hypothesis = calculateHypothesis(x, theta); + double[] normalized = Matrix.subtract(hypothesis, y); for (int j= 0; j < theta.length; j++) { - newTheta[j] = theta[j] - alpha * (1.0/m) * Matrix.sum(Matrix.add(hypotesisCache, Matrix.getColumn(x, j))); + newTheta[j] = theta[j] - alpha * (1.0/m) * Matrix.sum(Matrix.add(normalized, Matrix.getColumn(x, j))); } return newTheta; diff --git a/test/zutil/math/MatrixTest.java b/test/zutil/math/MatrixTest.java index bea2330..a19444a 100755 --- a/test/zutil/math/MatrixTest.java +++ b/test/zutil/math/MatrixTest.java @@ -106,8 +106,11 @@ public class MatrixTest { @Test public void vectorMultiply(){ assertArrayEquals( - new double[][]{{8},{14}}, - Matrix.multiply(new double[][]{{2,3},{-4,9}}, new double[]{1,2})); + new double[]{1.4, 1.9, 2.4, 2.9}, + Matrix.multiply( + new double[][]{{1, 2, 3}, {1, 3, 4}, {1, 4, 5}, {1, 5, 6}}, + new double[]{0.1, 0.2, 0.3}), + 0.001); } @Test diff --git a/test/zutil/ml/LinearRegressionTest.java b/test/zutil/ml/LinearRegressionTest.java index 4ecb21f..c15a8e4 100755 --- a/test/zutil/ml/LinearRegressionTest.java +++ b/test/zutil/ml/LinearRegressionTest.java @@ -12,12 +12,12 @@ public class LinearRegressionTest { @Test public void calculateHypotesis() { - double[][] hypotesis = LinearRegression.calculateHypotesis( - /* x */ new double[][]{{1, 2}, {1, 3}, {1, 4}, {1, 5}}, - /* theta */ new double[]{0.1, 0.2} + double[] hypotesis = LinearRegression.calculateHypothesis( + /* x */ new double[][]{{1, 2, 3}, {1, 3, 4}, {1, 4, 5}, {1, 5, 6}}, + /* theta */ new double[]{0.1, 0.2, 0.3} ); - assertArrayEquals(new double[][]{{0.5}, {0.7}, {0.9}, {1.1}}, hypotesis); + assertArrayEquals(new double[]{1.4, 1.9, 2.4, 2.9}, hypotesis, 0.001); } @Test @@ -33,7 +33,7 @@ public class LinearRegressionTest { @Test public void gradientAscent() { - double[] theta = LinearRegression.gradientAscent( + double[] theta = LinearRegression.gradientDescent( // one iteration /* x */ new double[][]{{1, 5},{1, 2},{1, 4},{1, 5}}, /* y */ new double[]{1, 6, 4, 2}, /* theta */ new double[]{0, 0}, diff --git a/test/zutil/test/ZutilAssert.java b/test/zutil/test/ZutilAssert.java new file mode 100644 index 0000000..0cec6ac --- /dev/null +++ b/test/zutil/test/ZutilAssert.java @@ -0,0 +1,47 @@ +package zutil.test; + +import org.junit.Assert; +import org.junit.internal.ArrayComparisonFailure; +import org.junit.internal.InexactComparisonCriteria; + +/** + * Some additional assert functions that are missing from JUnit + */ +public class ZutilAssert extends Assert { + + private ZutilAssert() {} + + /** + * Asserts that two short arrays are equal. If they are not, an + * {@link AssertionError} is thrown. + * + * @param expected double array with expected values. + * @param actual double array with actual values + */ + public static void assertArrayEquals(double[][] expected, double[][] actual, double delta) { + ZutilAssert.assertArrayEquals(null, expected, actual, delta); + } + + /** + * Asserts that two int arrays are equal. If they are not, an + * {@link AssertionError} is thrown with the given message. + * + * @param message the identifying message for the {@link AssertionError} (null + * okay) + * @param expected double array with expected values. + * @param actual double array with actual values + */ + public static void assertArrayEquals(String message, double[][] expected, + double[][] actual, double delta) throws ArrayComparisonFailure { + // If both arrays are referencing the same object or null + if (expected == actual) + return; + + // Check array lengths + if (expected.length != actual.length) + fail(message + ". The array lengths of the first dimensions do not match."); + + // Check all sub arrays + new InexactComparisonCriteria(delta).arrayEquals(message, expected, actual); + } +}