Initial commit of LinearRegression class

This commit is contained in:
Ziver Koc 2018-03-27 16:39:14 +02:00
parent 3c8c692b16
commit 8050170ee3
5 changed files with 306 additions and 80 deletions

View file

@ -4,7 +4,7 @@ package zutil.math;
* Some basic matrix match functions. * Some basic matrix match functions.
* Matrix definition: double[y][x]. * Matrix definition: double[y][x].
*/ */
public class MatrixMath { public class Matrix {
/*********************************************************************** /***********************************************************************
* Scalar * Scalar
@ -74,6 +74,8 @@ public class MatrixMath {
* Elemental * Elemental
**********************************************************************/ **********************************************************************/
public static class Elemental {
/** /**
* Element addition, each element in matrix1 will be * Element addition, each element in matrix1 will be
* added with the corresponding element in matrix2. * added with the corresponding element in matrix2.
@ -116,7 +118,7 @@ public class MatrixMath {
* *
* @return a new matrix with the result * @return a new matrix with the result
*/ */
public static double[][] elemMultiply(double[][] matrix1, double[][] matrix2){ public static double[][] multiply(double[][] matrix1, double[][] matrix2) {
elementalPreCheck(matrix1, matrix2); elementalPreCheck(matrix1, matrix2);
double[][] result = new double[matrix1.length][matrix1[0].length]; double[][] result = new double[matrix1.length][matrix1[0].length];
@ -128,17 +130,64 @@ public class MatrixMath {
return result; return result;
} }
/**
* Element exponential, each element in the vector will raised
* to the power of the exp paramter
*
* @return a new vector with the result
*/
public static double[] pow(double[] vector, double exp) {
double[] result = new double[vector.length];
for (int i = 0; i < vector.length; ++i) {
result[i] = Math.pow(vector[i], exp);
}
return result;
}
/**
* Element multiplication, each element in matrix1 will be
* multiplied with the corresponding element in matrix2.
*
* @return a new matrix with the result
*/
public static double[][] pow(double[][] matrix, double exp) {
double[][] result = new double[matrix.length][matrix[0].length];
for (int i = 0; i < matrix.length; ++i) {
result[i] = pow(matrix[i], exp);
}
return result;
}
private static void elementalPreCheck(double[][] matrix1, double[][] matrix2) { private static void elementalPreCheck(double[][] matrix1, double[][] matrix2) {
if (matrix1.length != matrix2.length || matrix1[0].length != matrix2[0].length) if (matrix1.length != matrix2.length || matrix1[0].length != matrix2[0].length)
throw new IllegalArgumentException("Matrices need to be of same dimension: " + throw new IllegalArgumentException("Matrices need to be of same dimension: " +
"matrix1 " + matrix1.length + "x" + matrix1[0].length + ", " + "matrix1 " + matrix1.length + "x" + matrix1[0].length + ", " +
"matrix2 " + matrix2.length + "x" + matrix2[0].length + ", "); "matrix2 " + matrix2.length + "x" + matrix2[0].length + ", ");
} }
}
/*********************************************************************** /***********************************************************************
* Vector * Vector
**********************************************************************/ **********************************************************************/
/**
* Vector subtraction, every element in the first vector will be subtracted
* with the corresponding element in the second vector.
*
* @return a new vector with subtracted elements
*/
public static double[] subtract(double[] vector1, double[] vector2){
vectorPreCheck(vector1, vector2);
double[] result = new double[vector1.length];
for (int i=0; i < result.length; ++i) {
result[i] = vector1[i] - vector2[i];
}
return result;
}
/** /**
* Matrix Vector multiplication, each element column in the matrix will be * Matrix Vector multiplication, each element column in the matrix will be
* multiplied with the corresponding element row in the vector. * multiplied with the corresponding element row in the vector.
@ -175,6 +224,24 @@ public class MatrixMath {
return result; return result;
} }
/**
* Sums all values in a vector
*
* @return the summed value of all elements in the vector
*/
public static double sum(double[] vector) {
double sum = 0;
for (int i = 0; i < vector.length; i++) {
sum += vector[i];
}
return sum;
}
private static void vectorPreCheck(double[] vector1, double[] vector2) {
if (vector1.length != vector2.length)
throw new IllegalArgumentException("The two vectors need to have the same length: " +
"vector1 " + vector1.length + "x1, vector2 " + vector2.length + "x1");
}
private static void vectorPreCheck(double[][] matrix, double[] vector) { private static void vectorPreCheck(double[][] matrix, double[] vector) {
if (matrix[0].length != vector.length) if (matrix[0].length != vector.length)
throw new IllegalArgumentException("Matrix columns need to have same length as vector length: " + throw new IllegalArgumentException("Matrix columns need to have same length as vector length: " +
@ -196,11 +263,10 @@ public class MatrixMath {
matrixPreCheck(matrix1, matrix2); matrixPreCheck(matrix1, matrix2);
double[][] result = new double[matrix1.length][matrix2[0].length]; double[][] result = new double[matrix1.length][matrix2[0].length];
for (int y=0; y < result.length; ++y) { for (int i=0; i < result.length; ++i) {
for (int x=0; x < result[y].length; ++x){ for (int k=0; k<matrix1[0].length; ++k) {
for (int j=0; j < result[i].length; ++j){
for (int i=0; i<matrix1[0].length; ++i) { result[i][j] += matrix1[i][k] * matrix2[k][j];
result[y][x] += matrix1[y][i] * matrix2[i][x];
} }
} }
} }
@ -221,6 +287,18 @@ public class MatrixMath {
return result; return result;
} }
/**
* Sums all values in a matrix
*
* @return the summed value of all elements
*/
public static double sum(double[][] matrix) {
double sum = 0;
for (int i = 0; i < matrix.length; i++) {
sum += sum(matrix[i]);
}
return sum;
}
private static void matrixPreCheck(double[][] matrix1, double[][] matrix2) { private static void matrixPreCheck(double[][] matrix1, double[][] matrix2) {
if (matrix1[0].length != matrix2.length) if (matrix1[0].length != matrix2.length)

View file

@ -0,0 +1,38 @@
package zutil.ml;
import zutil.math.Matrix;
/**
* Implementation of a Linear Regression algorithm for "predicting"
* numerical values depending on specific input
*/
public class LinearRegression {
/**
* Method for calculating a hypothesis value fr a specific input value x.
* <br><br>
* <i>
* h(x) = theta0 * x0 + theta1 * x1 + ... + thetan * xn => transpose(theta) * x
* </i>
*/
protected static double[] calculateHypotesis(double[][] x, double[] theta){
return Matrix.multiply(x, theta);
}
/**
* Linear Regresion cost method.
* <br /><br />
* <i>
* J(O) = 1 / (2 * m) * Σ { ( h(xi) - yi )^2 }
* </i><br>
* m = learning data size (rows)
* @return a number indicating the error rate
*/
protected static double calculateCost(double[][] x, double[] y, double[] theta){
return 1 / (2 * x.length) * Matrix.sum(
Matrix.Elemental.pow(
Matrix.subtract(calculateHypotesis(x, theta), y),
2));
}
}

View file

@ -0,0 +1,57 @@
package zutil.benchmark;
import com.carrotsearch.junitbenchmarks.BenchmarkRule;
import org.junit.Rule;
import org.junit.Test;
public class AnonymousFunctionBenchmark {
public static final int TEST_EXECUTIONS = 500;
@Rule
public BenchmarkRule benchmarkRun = new BenchmarkRule();
private int[] array = new int[100_000];
@Test
public void functionLoop() {
for(int k=0; k<TEST_EXECUTIONS; k++) {
for (int i = 0; i < array.length; i++) {
array[i] = new CalcFunc(){
public int calc(int i){
return i+1;
}
}.calc(i);
}
}
}
@Test
public void preFunctionLoop() {
CalcFunc func = new CalcFunc(){
public int calc(int i){
return i+1;
}
};
for(int k=0; k<TEST_EXECUTIONS; k++) {
for (int i = 0; i < array.length; i++) {
array[i] = func.calc(i);
}
}
}
@Test
public void rawLoops(){
for(int k=0; k<TEST_EXECUTIONS; k++) {
for (int i = 0; i < array.length; i++) {
array[i] = i;
}
}
}
private interface CalcFunc{
int calc(int i);
}
}

View file

@ -11,30 +11,51 @@ public class LoopBenchmark {
public BenchmarkRule benchmarkRun = new BenchmarkRule(); public BenchmarkRule benchmarkRun = new BenchmarkRule();
private int[] matrix = new int[100_000]; private int[] array1 = new int[100_000];
private int[] matrix2 = new int[50_000]; private int[] array2 = new int[50_000];
@Test @Test
public void oneLoop() { public void writeArrayOneLoop() {
for(int k=0; k<TEST_EXECUTIONS; k++) { for(int k=0; k<TEST_EXECUTIONS; k++) {
for (int i = 0; i < Math.max(matrix.length, matrix.length); i++) { for (int i = 0; i < Math.max(array1.length, array1.length); i++) {
if (i < matrix.length) if (i < array1.length)
matrix[i] = i; array1[i] = i;
if (i < matrix2.length) if (i < array2.length)
matrix2[i] = i; array2[i] = i;
} }
} }
} }
@Test @Test
public void twoLoops(){ public void writeArraySeparateLoops(){
for(int k=0; k<TEST_EXECUTIONS; k++) { for(int k=0; k<TEST_EXECUTIONS; k++) {
for (int i = 0; i < matrix.length; i++) { for (int i = 0; i < array1.length; i++) {
matrix[i] = i; array1[i] = i;
} }
for (int j = 0; j < matrix2.length; j++) { for (int j = 0; j < array2.length; j++) {
matrix2[j] = j; array2[j] = j;
}
}
}
@Test
public void readArrayLoop() {
int sum = 0;
for(int k=0; k<TEST_EXECUTIONS; k++) {
for (int i = 0; i < array1.length; i++) {
sum += array1[i];
}
}
}
@Test
public void readArrayForeach() {
int sum = 0;
for(int k=0; k<TEST_EXECUTIONS; k++) {
for (int i : array1) {
sum += array1[i];
} }
} }
} }

View file

@ -7,30 +7,30 @@ import static org.junit.Assert.*;
/** /**
* *
*/ */
public class MatrixMathTest { public class MatrixTest {
@Test @Test
public void scalarAdd(){ public void scalarAdd(){
assertArrayEquals(new double[][]{{4,5},{-2,11}}, assertArrayEquals(new double[][]{{4,5},{-2,11}},
MatrixMath.add(new double[][]{{2,3},{-4,9}}, 2)); Matrix.add(new double[][]{{2,3},{-4,9}}, 2));
} }
@Test @Test
public void scalarSubtraction(){ public void scalarSubtraction(){
assertArrayEquals(new double[][]{{0,1},{-6,7}}, assertArrayEquals(new double[][]{{0,1},{-6,7}},
MatrixMath.subtract(new double[][]{{2,3},{-4,9}}, 2)); Matrix.subtract(new double[][]{{2,3},{-4,9}}, 2));
} }
@Test @Test
public void scalarMultiply(){ public void scalarMultiply(){
assertArrayEquals(new double[][]{{4,6},{-8,18}}, assertArrayEquals(new double[][]{{4,6},{-8,18}},
MatrixMath.multiply(new double[][]{{2,3},{-4,9}}, 2)); Matrix.multiply(new double[][]{{2,3},{-4,9}}, 2));
} }
@Test @Test
public void scalarDivision(){ public void scalarDivision(){
assertArrayEquals(new double[][]{{1,2},{-2,5}}, assertArrayEquals(new double[][]{{1,2},{-2,5}},
MatrixMath.divide(new double[][]{{2,4},{-4,10}}, 2)); Matrix.divide(new double[][]{{2,4},{-4,10}}, 2));
} }
@ -38,19 +38,33 @@ public class MatrixMathTest {
@Test @Test
public void elementalAdd(){ public void elementalAdd(){
assertArrayEquals(new double[][]{{3,5},{-1,13}}, assertArrayEquals(new double[][]{{3,5},{-1,13}},
MatrixMath.add(new double[][]{{2,3},{-4,9}}, new double[][]{{1,2},{3,4}})); Matrix.Elemental.add(new double[][]{{2,3},{-4,9}}, new double[][]{{1,2},{3,4}}));
} }
@Test @Test
public void elementalSubtract(){ public void elementalSubtract(){
assertArrayEquals(new double[][]{{1,1},{-7,5}}, assertArrayEquals(new double[][]{{1,1},{-7,5}},
MatrixMath.subtract(new double[][]{{2,3},{-4,9}}, new double[][]{{1,2},{3,4}})); Matrix.Elemental.subtract(new double[][]{{2,3},{-4,9}}, new double[][]{{1,2},{3,4}}));
} }
@Test @Test
public void elementalMultiply(){ public void elementalMultiply(){
assertArrayEquals(new double[][]{{2,6},{-12,36}}, assertArrayEquals(new double[][]{{2,6},{-12,36}},
MatrixMath.elemMultiply(new double[][]{{2,3},{-4,9}}, new double[][]{{1,2},{3,4}})); Matrix.Elemental.multiply(new double[][]{{2,3},{-4,9}}, new double[][]{{1,2},{3,4}}));
}
@Test
public void elementalVectorPow(){
assertArrayEquals(
new double[]{4,9,16,81},
Matrix.Elemental.pow(new double[]{2,3,-4,9}, 2),
0.0);
}
@Test
public void elementalMatrixPow(){
assertArrayEquals(new double[][]{{4,9},{16,81}},
Matrix.Elemental.pow(new double[][]{{2,3},{-4,9}}, 2));
} }
@ -59,7 +73,7 @@ public class MatrixMathTest {
public void vectorMultiply(){ public void vectorMultiply(){
assertArrayEquals( assertArrayEquals(
new double[]{8,14}, new double[]{8,14},
MatrixMath.multiply(new double[][]{{2,3},{-4,9}}, new double[]{1,2}), Matrix.multiply(new double[][]{{2,3},{-4,9}}, new double[]{1,2}),
0.0 0.0
); );
} }
@ -68,18 +82,27 @@ public class MatrixMathTest {
public void vectorDivision(){ public void vectorDivision(){
assertArrayEquals( assertArrayEquals(
new double[]{4,1}, new double[]{4,1},
MatrixMath.divide(new double[][]{{2,4},{-4,10}}, new double[]{1,2}), Matrix.divide(new double[][]{{2,4},{-4,10}}, new double[]{1,2}),
0.0 0.0
); );
} }
@Test
public void vectorSum(){
assertEquals(
20.0,
Matrix.sum(new double[]{1,2,0,3,5,9}),
0.02
);
}
@Test @Test
public void matrixMultiply(){ public void matrixMultiply(){
assertArrayEquals( assertArrayEquals(
new double[][]{{486,410.4,691.6},{314,341.6,416.4},{343.5,353.4,463.6},{173,285.2,190.8}}, new double[][]{{486,410.4,691.6},{314,341.6,416.4},{343.5,353.4,463.6},{173,285.2,190.8}},
MatrixMath.multiply( Matrix.multiply(
new double[][]{{1,2104},{1,1416},{1,1534},{1,852}}, new double[][]{{1,2104},{1,1416},{1,1534},{1,852}},
new double[][]{{-40,200,-150},{0.25,0.1,0.4}}) new double[][]{{-40,200,-150},{0.25,0.1,0.4}})
); );
@ -89,21 +112,30 @@ public class MatrixMathTest {
public void matrixTranspose(){ public void matrixTranspose(){
assertArrayEquals( assertArrayEquals(
new double[][]{{1,3},{2,5},{0,9}}, new double[][]{{1,3},{2,5},{0,9}},
MatrixMath.transpose( Matrix.transpose(
new double[][]{{1,2,0},{3,5,9}}) new double[][]{{1,2,0},{3,5,9}})
); );
} }
@Test
public void matrixSum(){
assertEquals(
20.0,
Matrix.sum(new double[][]{{1,2,0},{3,5,9}}),
0.02
);
}
@Test @Test
public void identity(){ public void identity(){
assertArrayEquals( assertArrayEquals(
new double[][]{{1}}, new double[][]{{1}},
MatrixMath.identity(1)); Matrix.identity(1));
assertArrayEquals( assertArrayEquals(
new double[][]{{1,0,0,0},{0,1,0,0},{0,0,1,0},{0,0,0,1}}, new double[][]{{1,0,0,0},{0,1,0,0},{0,0,1,0},{0,0,0,1}},
MatrixMath.identity(4)); Matrix.identity(4));
} }
} }