diff --git a/Zutil.iml b/Zutil.iml
index a2bde11..5454eb9 100755
--- a/Zutil.iml
+++ b/Zutil.iml
@@ -45,5 +45,15 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/zutil/math/Matrix.java b/src/zutil/math/Matrix.java
index 9ec3e74..6ca6d76 100755
--- a/src/zutil/math/Matrix.java
+++ b/src/zutil/math/Matrix.java
@@ -226,7 +226,7 @@ public class Matrix {
* Matrix Vector subtraction, each column in the matrix will be subtracted
* with the vector.
*
- * @return a new vector with subtracted elements
+ * @return a new matrix with subtracted elements
*/
public static double[][] subtract(double[][] matrix, double[] vector){
vectorPreCheck(matrix, vector);
@@ -247,15 +247,13 @@ public class Matrix {
*
* @return a new vector with the result
*/
- public static double[][] multiply(double[][] matrix, double[] vector){
+ public static double[] multiply(double[][] matrix, double[] vector){
vectorPreCheck(matrix, vector);
- double[][] result = new double[matrix.length][1];
+ double[] result = new double[matrix.length];
- for (int y=0; y < result.length; ++y) {
+ for (int y=0; y < matrix.length; ++y) {
for (int x=0; x transpose(theta) * x
*
*/
- protected static double[][] calculateHypotesis(double[][] x, double[] theta){
+ protected static double[] calculateHypothesis(double[][] x, double[] theta){
return Matrix.multiply(x, theta);
}
/**
- * Linear Regresion cost method.
+ * Linear Regression cost method.
*
*
* J(O) = 1 / (2 * m) * Σ { ( h(Xi) - Yi )^2 }
@@ -30,10 +30,11 @@ public class LinearRegression {
* @return a number indicating the error rate
*/
protected static double calculateCost(double[][] x, double[] y, double[] theta){
+ double[] hypothesis = calculateHypothesis(x, theta);
+ double[] normalized = Matrix.subtract(hypothesis, y);
+
return 1.0 / (2.0 * x.length) * Matrix.sum(
- Matrix.Elemental.pow(
- Matrix.subtract(calculateHypotesis(x, theta), y),
- 2));
+ Matrix.Elemental.pow(normalized,2));
}
/**
@@ -45,13 +46,14 @@ public class LinearRegression {
*
* @return the theta that was found to minimize the cost function
*/
- public static double[] gradientAscent(double[][] x, double[] y, double[] theta, double alpha){
+ public static double[] gradientDescent(double[][] x, double[] y, double[] theta, double alpha){
double[] newTheta = new double[theta.length];
double m = y.length;
- double[][] hypotesisCache = Matrix.subtract(calculateHypotesis(x, theta), y);
+ double[] hypothesis = calculateHypothesis(x, theta);
+ double[] normalized = Matrix.subtract(hypothesis, y);
for (int j= 0; j < theta.length; j++) {
- newTheta[j] = theta[j] - alpha * (1.0/m) * Matrix.sum(Matrix.add(hypotesisCache, Matrix.getColumn(x, j)));
+ newTheta[j] = theta[j] - alpha * (1.0/m) * Matrix.sum(Matrix.add(normalized, Matrix.getColumn(x, j)));
}
return newTheta;
diff --git a/test/zutil/math/MatrixTest.java b/test/zutil/math/MatrixTest.java
index bea2330..a19444a 100755
--- a/test/zutil/math/MatrixTest.java
+++ b/test/zutil/math/MatrixTest.java
@@ -106,8 +106,11 @@ public class MatrixTest {
@Test
public void vectorMultiply(){
assertArrayEquals(
- new double[][]{{8},{14}},
- Matrix.multiply(new double[][]{{2,3},{-4,9}}, new double[]{1,2}));
+ new double[]{1.4, 1.9, 2.4, 2.9},
+ Matrix.multiply(
+ new double[][]{{1, 2, 3}, {1, 3, 4}, {1, 4, 5}, {1, 5, 6}},
+ new double[]{0.1, 0.2, 0.3}),
+ 0.001);
}
@Test
diff --git a/test/zutil/ml/LinearRegressionTest.java b/test/zutil/ml/LinearRegressionTest.java
index 4ecb21f..c15a8e4 100755
--- a/test/zutil/ml/LinearRegressionTest.java
+++ b/test/zutil/ml/LinearRegressionTest.java
@@ -12,12 +12,12 @@ public class LinearRegressionTest {
@Test
public void calculateHypotesis() {
- double[][] hypotesis = LinearRegression.calculateHypotesis(
- /* x */ new double[][]{{1, 2}, {1, 3}, {1, 4}, {1, 5}},
- /* theta */ new double[]{0.1, 0.2}
+ double[] hypotesis = LinearRegression.calculateHypothesis(
+ /* x */ new double[][]{{1, 2, 3}, {1, 3, 4}, {1, 4, 5}, {1, 5, 6}},
+ /* theta */ new double[]{0.1, 0.2, 0.3}
);
- assertArrayEquals(new double[][]{{0.5}, {0.7}, {0.9}, {1.1}}, hypotesis);
+ assertArrayEquals(new double[]{1.4, 1.9, 2.4, 2.9}, hypotesis, 0.001);
}
@Test
@@ -33,7 +33,7 @@ public class LinearRegressionTest {
@Test
public void gradientAscent() {
- double[] theta = LinearRegression.gradientAscent(
+ double[] theta = LinearRegression.gradientDescent( // one iteration
/* x */ new double[][]{{1, 5},{1, 2},{1, 4},{1, 5}},
/* y */ new double[]{1, 6, 4, 2},
/* theta */ new double[]{0, 0},
diff --git a/test/zutil/test/ZutilAssert.java b/test/zutil/test/ZutilAssert.java
new file mode 100644
index 0000000..0cec6ac
--- /dev/null
+++ b/test/zutil/test/ZutilAssert.java
@@ -0,0 +1,47 @@
+package zutil.test;
+
+import org.junit.Assert;
+import org.junit.internal.ArrayComparisonFailure;
+import org.junit.internal.InexactComparisonCriteria;
+
+/**
+ * Some additional assert functions that are missing from JUnit
+ */
+public class ZutilAssert extends Assert {
+
+ private ZutilAssert() {}
+
+ /**
+ * Asserts that two short arrays are equal. If they are not, an
+ * {@link AssertionError} is thrown.
+ *
+ * @param expected double array with expected values.
+ * @param actual double array with actual values
+ */
+ public static void assertArrayEquals(double[][] expected, double[][] actual, double delta) {
+ ZutilAssert.assertArrayEquals(null, expected, actual, delta);
+ }
+
+ /**
+ * Asserts that two int arrays are equal. If they are not, an
+ * {@link AssertionError} is thrown with the given message.
+ *
+ * @param message the identifying message for the {@link AssertionError} (null
+ * okay)
+ * @param expected double array with expected values.
+ * @param actual double array with actual values
+ */
+ public static void assertArrayEquals(String message, double[][] expected,
+ double[][] actual, double delta) throws ArrayComparisonFailure {
+ // If both arrays are referencing the same object or null
+ if (expected == actual)
+ return;
+
+ // Check array lengths
+ if (expected.length != actual.length)
+ fail(message + ". The array lengths of the first dimensions do not match.");
+
+ // Check all sub arrays
+ new InexactComparisonCriteria(delta).arrayEquals(message, expected, actual);
+ }
+}