Added more elemental matrix math and a non functioning gradient decent function.

This commit is contained in:
Ziver Koc 2018-09-05 16:37:19 +02:00
parent f84901dcc3
commit cf94310598
4 changed files with 123 additions and 6 deletions

View file

@ -146,6 +146,42 @@ public class Matrix {
return result;
}
/**
* Matrix Vector elemental multiplication, each element column in the matrix will be
* multiplied with the column of the vector.
*
* @return a new matrix with the result
*/
public static double[][] multiply(double[][] matrix, double[] vector){
vectorPreCheck(matrix, vector);
double[][] result = new double[matrix.length][matrix[0].length];
for (int y=0; y < matrix.length; ++y) {
for (int x=0; x<matrix[0].length; ++x) {
result[y][x] = matrix[y][x] * vector[x];
}
}
return result;
}
/**
* Matrix Vector elemental division, each element column in the matrix will be
* divided with the column of the vector.
*
* @return a new matrix with the result
*/
public static double[][] divide(double[][] matrix, double[] vector){
vectorPreCheck(matrix, vector);
double[][] result = new double[matrix.length][matrix[0].length];
for (int y=0; y < matrix.length; ++y) {
for (int x=0; x<matrix[0].length; ++x) {
result[y][x] = matrix[y][x] / vector[x];
}
}
return result;
}
/**
* Element exponential, each element in the vector will raised
* to the power of the exp paramter
@ -181,7 +217,7 @@ public class Matrix {
if (matrix1.length != matrix2.length || matrix1[0].length != matrix2[0].length)
throw new IllegalArgumentException("Matrices need to be of same dimension: " +
"matrix1 " + matrix1.length + "x" + matrix1[0].length + ", " +
"matrix2 " + matrix2.length + "x" + matrix2[0].length + ", ");
"matrix2 " + matrix2.length + "x" + matrix2[0].length);
}
}
@ -374,7 +410,7 @@ public class Matrix {
if (matrix1[0].length != matrix2.length)
throw new IllegalArgumentException("Matrix1 columns need to match Matrix2 rows: " +
"matrix1 " + matrix1.length + "x" + matrix1[0].length + ", " +
"matrix2 " + matrix2.length + "x" + matrix2[0].length + ", ");
"matrix2 " + matrix2.length + "x" + matrix2[0].length);
}
/***********************************************************************

View file

@ -1,13 +1,16 @@
package zutil.ml;
import zutil.log.LogUtil;
import zutil.math.Matrix;
import java.util.logging.Logger;
/**
* Implementation of a Linear Regression algorithm for "predicting"
* Implementation of a Linear Regression algorithm for predicting
* numerical values depending on specific input
*/
public class LinearRegression {
private static final Logger logger = LogUtil.getLogger();
/**
* Method for calculating a hypothesis value fr a specific input value x.
@ -37,6 +40,34 @@ public class LinearRegression {
Matrix.Elemental.pow(normalized,2));
}
/**
* Calculates the gradiant of the current provided theta.
*/
protected static double calculateGradiant(double[][] x, double[] y, double[] theta){
int m = y.length; // number of training examples
double[] hypothesis = calculateHypothesis(x, theta);
double[] normalized = Matrix.subtract(hypothesis, y);
return 1/m * Matrix.sum(
Matrix.Elemental.multiply(Matrix.transpose(x), normalized));
}
/**
* Will try to find the best theta value.
*/
public static double[] gradientDescent(double[][] x, double[] y, double[] theta, double alpha){
double[] newTheta = theta.clone();
double gradient;
for (int i=0; (gradient = calculateGradiant(x, y, newTheta)) != 0; i++) {
logger.fine("Gradient Descent iteration " + i + ", gradiant: " + gradient);
newTheta = gradientDescentIteration(x, y, newTheta, alpha);
}
return newTheta;
}
/**
* Gradient Descent algorithm
* <br /><br />