Fixed gradient decent
This commit is contained in:
parent
cf94310598
commit
98f2219366
2 changed files with 44 additions and 22 deletions
|
|
@ -1,6 +1,7 @@
|
|||
package zutil.ml;
|
||||
|
||||
import org.junit.Test;
|
||||
import zutil.io.MultiPrintStream;
|
||||
import zutil.log.LogUtil;
|
||||
|
||||
import java.util.logging.Level;
|
||||
|
|
@ -36,8 +37,8 @@ public class LinearRegressionTest {
|
|||
}
|
||||
|
||||
// Does not work
|
||||
//@Test
|
||||
public void gradientAscent() {
|
||||
@Test
|
||||
public void gradientDescent() {
|
||||
double[][] x = {
|
||||
{1.0, 0.1, 0.6, 1.1},
|
||||
{1.0, 0.2, 0.7, 1.2},
|
||||
|
|
@ -59,14 +60,40 @@ public class LinearRegressionTest {
|
|||
2
|
||||
};
|
||||
|
||||
double[] resultTheta = LinearRegression.gradientDescent(x, y, theta, 0);
|
||||
// Alpha zero
|
||||
|
||||
assertEquals(0.73482, LinearRegression.calculateCost(x, y, resultTheta), 0.000001);
|
||||
double[] resultTheta = LinearRegression.gradientDescent(x, y, theta, 0);
|
||||
System.out.println("Result Theta (alpha = 0):");
|
||||
System.out.println(MultiPrintStream.dumpToString(resultTheta));
|
||||
|
||||
assertArrayEquals(theta, resultTheta, 0.000001);
|
||||
|
||||
// Alpha +
|
||||
|
||||
resultTheta = LinearRegression.gradientDescent(x, y, theta, 0.1);
|
||||
System.out.println("Result Theta (alpha = 0.1):");
|
||||
System.out.println(MultiPrintStream.dumpToString(resultTheta));
|
||||
|
||||
assertArrayEquals(
|
||||
new double[]{-1.31221, -1.98259, 0.36131, 1.70520},
|
||||
resultTheta, 0.001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void gradientAscentIteration() {
|
||||
double[] theta = LinearRegression.gradientDescentIteration( // one iteration
|
||||
public void gradientDescentIteration() {
|
||||
// Zero iterations
|
||||
|
||||
double[] theta = LinearRegression.gradientDescentIteration(
|
||||
/* x */ new double[][]{{1, 5},{1, 2},{1, 4},{1, 5}},
|
||||
/* y */ new double[]{1, 6, 4, 2},
|
||||
/* theta */ new double[]{0, 0},
|
||||
/* alpha */0.0);
|
||||
|
||||
assertArrayEquals(new double[]{0.0, 0.0}, theta, 0.000001);
|
||||
|
||||
// One iteration
|
||||
|
||||
theta = LinearRegression.gradientDescentIteration(
|
||||
/* x */ new double[][]{{1, 5},{1, 2},{1, 4},{1, 5}},
|
||||
/* y */ new double[]{1, 6, 4, 2},
|
||||
/* theta */ new double[]{0, 0},
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue