Fixed some ml test cases

This commit is contained in:
Ziver Koc 2018-05-21 16:06:46 +02:00
parent 2160976406
commit 45b1f51685
6 changed files with 83 additions and 23 deletions

View file

@ -45,5 +45,15 @@
<orderEntry type="library" scope="TEST" name="Maven: junit:junit:4.12" level="project" /> <orderEntry type="library" scope="TEST" name="Maven: junit:junit:4.12" level="project" />
<orderEntry type="library" name="Maven: org.hamcrest:hamcrest-core:1.3" level="project" /> <orderEntry type="library" name="Maven: org.hamcrest:hamcrest-core:1.3" level="project" />
<orderEntry type="library" name="Maven: com.carrotsearch:junit-benchmarks:0.7.2" level="project" /> <orderEntry type="library" name="Maven: com.carrotsearch:junit-benchmarks:0.7.2" level="project" />
<orderEntry type="library" name="Maven: commons-fileupload:commons-fileupload:1.2.1" level="project" />
<orderEntry type="library" name="Maven: commons-io:commons-io:2.5" level="project" />
<orderEntry type="library" name="Maven: dom4j:dom4j:1.6.1" level="project" />
<orderEntry type="library" name="Maven: xml-apis:xml-apis:1.0.b2" level="project" />
<orderEntry type="library" scope="PROVIDED" name="Maven: javax.servlet:javax.servlet-api:3.1.0" level="project" />
<orderEntry type="library" name="Maven: mysql:mysql-connector-java:5.1.36" level="project" />
<orderEntry type="library" name="Maven: org.xerial:sqlite-jdbc:3.8.11.2" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: junit:junit:4.12" level="project" />
<orderEntry type="library" name="Maven: org.hamcrest:hamcrest-core:1.3" level="project" />
<orderEntry type="library" name="Maven: com.carrotsearch:junit-benchmarks:0.7.2" level="project" />
</component> </component>
</module> </module>

View file

@ -226,7 +226,7 @@ public class Matrix {
* Matrix Vector subtraction, each column in the matrix will be subtracted * Matrix Vector subtraction, each column in the matrix will be subtracted
* with the vector. * with the vector.
* *
* @return a new vector with subtracted elements * @return a new matrix with subtracted elements
*/ */
public static double[][] subtract(double[][] matrix, double[] vector){ public static double[][] subtract(double[][] matrix, double[] vector){
vectorPreCheck(matrix, vector); vectorPreCheck(matrix, vector);
@ -247,15 +247,13 @@ public class Matrix {
* *
* @return a new vector with the result * @return a new vector with the result
*/ */
public static double[][] multiply(double[][] matrix, double[] vector){ public static double[] multiply(double[][] matrix, double[] vector){
vectorPreCheck(matrix, vector); vectorPreCheck(matrix, vector);
double[][] result = new double[matrix.length][1]; double[] result = new double[matrix.length];
for (int y=0; y < result.length; ++y) { for (int y=0; y < matrix.length; ++y) {
for (int x=0; x<matrix[0].length; ++x) { for (int x=0; x<matrix[0].length; ++x) {
for (int i=0; i < result[y].length; ++i){ result[y] += matrix[y][x] * vector[x];
result[y][i] += matrix[y][x] * vector[x];
}
} }
} }
return result; return result;
@ -299,7 +297,7 @@ public class Matrix {
} }
private static void vectorPreCheck(double[][] matrix, double[] vector) { private static void vectorPreCheck(double[][] matrix, double[] vector) {
if (matrix[0].length != vector.length) if (matrix[0].length != vector.length)
throw new IllegalArgumentException("Matrix columns need to have same length as vector length: " + throw new IllegalArgumentException("Matrix columns need to have same length as the vector length: " +
"matrix " + matrix.length + "x" + matrix[0].length + ", " + "matrix " + matrix.length + "x" + matrix[0].length + ", " +
"vector " + vector.length + "x1"); "vector " + vector.length + "x1");
} }

View file

@ -16,12 +16,12 @@ public class LinearRegression {
* h(x) = theta0 * x0 + theta1 * x1 + ... + thetan * xn => transpose(theta) * x * h(x) = theta0 * x0 + theta1 * x1 + ... + thetan * xn => transpose(theta) * x
* </i> * </i>
*/ */
protected static double[][] calculateHypotesis(double[][] x, double[] theta){ protected static double[] calculateHypothesis(double[][] x, double[] theta){
return Matrix.multiply(x, theta); return Matrix.multiply(x, theta);
} }
/** /**
* Linear Regresion cost method. * Linear Regression cost method.
* <br /><br /> * <br /><br />
* <i> * <i>
* J(O) = 1 / (2 * m) * Σ { ( h(Xi) - Yi )^2 } * J(O) = 1 / (2 * m) * Σ { ( h(Xi) - Yi )^2 }
@ -30,10 +30,11 @@ public class LinearRegression {
* @return a number indicating the error rate * @return a number indicating the error rate
*/ */
protected static double calculateCost(double[][] x, double[] y, double[] theta){ protected static double calculateCost(double[][] x, double[] y, double[] theta){
double[] hypothesis = calculateHypothesis(x, theta);
double[] normalized = Matrix.subtract(hypothesis, y);
return 1.0 / (2.0 * x.length) * Matrix.sum( return 1.0 / (2.0 * x.length) * Matrix.sum(
Matrix.Elemental.pow( Matrix.Elemental.pow(normalized,2));
Matrix.subtract(calculateHypotesis(x, theta), y),
2));
} }
/** /**
@ -45,13 +46,14 @@ public class LinearRegression {
* *
* @return the theta that was found to minimize the cost function * @return the theta that was found to minimize the cost function
*/ */
public static double[] gradientAscent(double[][] x, double[] y, double[] theta, double alpha){ public static double[] gradientDescent(double[][] x, double[] y, double[] theta, double alpha){
double[] newTheta = new double[theta.length]; double[] newTheta = new double[theta.length];
double m = y.length; double m = y.length;
double[][] hypotesisCache = Matrix.subtract(calculateHypotesis(x, theta), y); double[] hypothesis = calculateHypothesis(x, theta);
double[] normalized = Matrix.subtract(hypothesis, y);
for (int j= 0; j < theta.length; j++) { for (int j= 0; j < theta.length; j++) {
newTheta[j] = theta[j] - alpha * (1.0/m) * Matrix.sum(Matrix.add(hypotesisCache, Matrix.getColumn(x, j))); newTheta[j] = theta[j] - alpha * (1.0/m) * Matrix.sum(Matrix.add(normalized, Matrix.getColumn(x, j)));
} }
return newTheta; return newTheta;

View file

@ -106,8 +106,11 @@ public class MatrixTest {
@Test @Test
public void vectorMultiply(){ public void vectorMultiply(){
assertArrayEquals( assertArrayEquals(
new double[][]{{8},{14}}, new double[]{1.4, 1.9, 2.4, 2.9},
Matrix.multiply(new double[][]{{2,3},{-4,9}}, new double[]{1,2})); Matrix.multiply(
new double[][]{{1, 2, 3}, {1, 3, 4}, {1, 4, 5}, {1, 5, 6}},
new double[]{0.1, 0.2, 0.3}),
0.001);
} }
@Test @Test

View file

@ -12,12 +12,12 @@ public class LinearRegressionTest {
@Test @Test
public void calculateHypotesis() { public void calculateHypotesis() {
double[][] hypotesis = LinearRegression.calculateHypotesis( double[] hypotesis = LinearRegression.calculateHypothesis(
/* x */ new double[][]{{1, 2}, {1, 3}, {1, 4}, {1, 5}}, /* x */ new double[][]{{1, 2, 3}, {1, 3, 4}, {1, 4, 5}, {1, 5, 6}},
/* theta */ new double[]{0.1, 0.2} /* theta */ new double[]{0.1, 0.2, 0.3}
); );
assertArrayEquals(new double[][]{{0.5}, {0.7}, {0.9}, {1.1}}, hypotesis); assertArrayEquals(new double[]{1.4, 1.9, 2.4, 2.9}, hypotesis, 0.001);
} }
@Test @Test
@ -33,7 +33,7 @@ public class LinearRegressionTest {
@Test @Test
public void gradientAscent() { public void gradientAscent() {
double[] theta = LinearRegression.gradientAscent( double[] theta = LinearRegression.gradientDescent( // one iteration
/* x */ new double[][]{{1, 5},{1, 2},{1, 4},{1, 5}}, /* x */ new double[][]{{1, 5},{1, 2},{1, 4},{1, 5}},
/* y */ new double[]{1, 6, 4, 2}, /* y */ new double[]{1, 6, 4, 2},
/* theta */ new double[]{0, 0}, /* theta */ new double[]{0, 0},

View file

@ -0,0 +1,47 @@
package zutil.test;
import org.junit.Assert;
import org.junit.internal.ArrayComparisonFailure;
import org.junit.internal.InexactComparisonCriteria;
/**
* Some additional assert functions that are missing from JUnit
*/
public class ZutilAssert extends Assert {
private ZutilAssert() {}
/**
* Asserts that two short arrays are equal. If they are not, an
* {@link AssertionError} is thrown.
*
* @param expected double array with expected values.
* @param actual double array with actual values
*/
public static void assertArrayEquals(double[][] expected, double[][] actual, double delta) {
ZutilAssert.assertArrayEquals(null, expected, actual, delta);
}
/**
* Asserts that two int arrays are equal. If they are not, an
* {@link AssertionError} is thrown with the given message.
*
* @param message the identifying message for the {@link AssertionError} (<code>null</code>
* okay)
* @param expected double array with expected values.
* @param actual double array with actual values
*/
public static void assertArrayEquals(String message, double[][] expected,
double[][] actual, double delta) throws ArrayComparisonFailure {
// If both arrays are referencing the same object or null
if (expected == actual)
return;
// Check array lengths
if (expected.length != actual.length)
fail(message + ". The array lengths of the first dimensions do not match.");
// Check all sub arrays
new InexactComparisonCriteria(delta).arrayEquals(message, expected, actual);
}
}