Implemented Gradient Ascent (TCs failing)
This commit is contained in:
parent
8050170ee3
commit
2160976406
4 changed files with 187 additions and 15 deletions
|
|
@ -70,14 +70,46 @@ public class MatrixTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void vectorMultiply(){
|
||||
public void vectorAddition(){
|
||||
assertArrayEquals(
|
||||
new double[]{8,14},
|
||||
Matrix.multiply(new double[][]{{2,3},{-4,9}}, new double[]{1,2}),
|
||||
new double[]{3,5,-1,13},
|
||||
Matrix.add(new double[]{2,3,-4,9}, new double[]{1,2,3,4}),
|
||||
0.0
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void vectorMatrixAddition(){
|
||||
assertArrayEquals(
|
||||
new double[][]{{2,3,4,5},{2,3,4,5},{2,3,4,5},{2,3,4,5}},
|
||||
Matrix.add(new double[][]{{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}}, new double[]{1,1,1,1})
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void vectorSubtraction(){
|
||||
assertArrayEquals(
|
||||
new double[]{1,1,-7,5},
|
||||
Matrix.subtract(new double[]{2,3,-4,9}, new double[]{1,2,3,4}),
|
||||
0.0
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void vectorMatrixSubtraction(){
|
||||
assertArrayEquals(
|
||||
new double[][]{{0,1,2,3},{0,1,2,3},{0,1,2,3},{0,1,2,3}},
|
||||
Matrix.subtract(new double[][]{{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}}, new double[]{1,1,1,1})
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void vectorMultiply(){
|
||||
assertArrayEquals(
|
||||
new double[][]{{8},{14}},
|
||||
Matrix.multiply(new double[][]{{2,3},{-4,9}}, new double[]{1,2}));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void vectorDivision(){
|
||||
assertArrayEquals(
|
||||
|
|
@ -138,4 +170,13 @@ public class MatrixTest {
|
|||
new double[][]{{1,0,0,0},{0,1,0,0},{0,0,1,0},{0,0,0,1}},
|
||||
Matrix.identity(4));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getColumn(){
|
||||
assertArrayEquals(
|
||||
new double[]{2,3,4,1},
|
||||
Matrix.getColumn(new double[][]{{1,2,3,4},{2,3,4,1},{3,4,1,2},{4,1,2,3}}, 1),
|
||||
0.0
|
||||
);
|
||||
}
|
||||
}
|
||||
44
test/zutil/ml/LinearRegressionTest.java
Executable file
44
test/zutil/ml/LinearRegressionTest.java
Executable file
|
|
@ -0,0 +1,44 @@
|
|||
package zutil.ml;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Test cases are from the Machine Learning course on coursera.
|
||||
* https://www.coursera.org/learn/machine-learning/discussions/all/threads/0SxufTSrEeWPACIACw4G5w
|
||||
*/
|
||||
public class LinearRegressionTest {
|
||||
|
||||
@Test
|
||||
public void calculateHypotesis() {
|
||||
double[][] hypotesis = LinearRegression.calculateHypotesis(
|
||||
/* x */ new double[][]{{1, 2}, {1, 3}, {1, 4}, {1, 5}},
|
||||
/* theta */ new double[]{0.1, 0.2}
|
||||
);
|
||||
|
||||
assertArrayEquals(new double[][]{{0.5}, {0.7}, {0.9}, {1.1}}, hypotesis);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void calculateCost() {
|
||||
double cost = LinearRegression.calculateCost(
|
||||
/* x */ new double[][]{{1, 2}, {1, 3}, {1, 4}, {1, 5}},
|
||||
/* y */ new double[]{7, 6, 5, 4},
|
||||
/* theta */ new double[]{0.1, 0.2}
|
||||
);
|
||||
|
||||
assertEquals(11.9450, cost, 0.0001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void gradientAscent() {
|
||||
double[] theta = LinearRegression.gradientAscent(
|
||||
/* x */ new double[][]{{1, 5},{1, 2},{1, 4},{1, 5}},
|
||||
/* y */ new double[]{1, 6, 4, 2},
|
||||
/* theta */ new double[]{0, 0},
|
||||
/* alpha */0.01);
|
||||
|
||||
assertArrayEquals(new double[]{0.032500, 0.107500}, theta, 0.000001);
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue