Generalized Linear Model implementation in Java

This package implements generalized linear models (GLM) in Java.


Install

Add the following to the dependencies section of your pom file:

<dependency>
  <groupId>com.github.chen0040</groupId>
  <artifactId>java-glm</artifactId>
  <version>1.0.6</version>
</dependency>

Features

The current implementation of GLM supports as many distribution families as the glm package in R:

  • Normal
  • Exponential
  • Gamma
  • InverseGaussian
  • Poisson
  • Bernoulli
  • Binomial
  • Categorical
  • Multinomial

For the solvers, the current implementation of GLM supports a number of variants of the iteratively re-weighted least squares (IRLS) estimation algorithm, as shown in the sketch after this list:

  • IRLS
  • IRLS with QR factorization
  • IRLS with SVD factorization
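
A minimal sketch of selecting a solver; Glm.logistic(), GlmIrls, and GlmIrlsQr appear in the usage examples below, while the enum name for the SVD variant is an assumption:

Glm glm = Glm.logistic();
glm.setSolverType(GlmSolverType.GlmIrls);            // plain IRLS
// or: glm.setSolverType(GlmSolverType.GlmIrlsQr);   // IRLS with QR factorization
// or: glm.setSolverType(GlmSolverType.GlmIrlsSvd);  // IRLS with SVD factorization (enum name assumed)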

Usage

Step 1: Create the GLM and train it against the training data

Suppose you want to create a logistic regression model from GLM and train it against a data frame:

import com.github.chen0040.glm.solvers.Glm;
import com.github.chen0040.glm.enums.GlmSolverType;
import com.github.chen0040.data.frame.DataFrame; // data frames come from the companion java-data-frame library

DataFrame trainingData = loadTrainingData();

Glm glm = Glm.logistic();
glm.setSolverType(GlmSolverType.GlmIrls);
glm.fit(trainingData);

The "trainingData" is a data frame (Please refers to this link on how to create a data frame from file or from scratch)

The line "Glm.logistic()" create the logistic regression model, which can be easily changed to create other regression models (For example, calling "Glm.linear()" create a linear regression model)

The line "glm.fit(..)" performs the GLM training.

Step 2: Use the trained regression model to predict on new data

The trained glm can then be run on the testing data. Below is a Java code example for logistic regression:

DataFrame testingData = loadTestingData();
for(int i = 0; i < testingData.rowCount(); ++i){
    boolean predicted = glm.transform(testingData.row(i)) > 0.5;
    boolean actual = testingData.row(i).target() > 0.5;
    System.out.println("predicted(Irls): " + predicted + "\texpected: " + actual);
}

The "testingData" is a data frame

The line "glm.transform(..)" perform the regression

Sample code

Sample code for linear regression

The sample code below shows a linear regression example:

DataQuery.DataFrameQueryBuilder schema = DataQuery.blank()
      .newInput("x1")
      .newInput("x2")
      .newOutput("y")
      .end();

// y = 4 + 0.5 * x1 + 0.2 * x2
Sampler.DataSampleBuilder sampler = new Sampler()
      .forColumn("x1").generate((name, index) -> randn() * 0.3 + index)
      .forColumn("x2").generate((name, index) -> randn() * 0.3 + index * index)
      .forColumn("y").generate((name, index) -> 4 + 0.5 * index + 0.2 * index * index + randn() * 0.3)
      .end();

DataFrame trainingData = schema.build();

trainingData = sampler.sample(trainingData, 200);

System.out.println(trainingData.head(10));

DataFrame crossValidationData = schema.build();

crossValidationData = sampler.sample(crossValidationData, 40);

Glm glm = Glm.linear();
glm.setSolverType(GlmSolverType.GlmIrlsQr);
glm.fit(trainingData);

for(int i = 0; i < crossValidationData.rowCount(); ++i){
    double predicted = glm.transform(crossValidationData.row(i));
    double actual = crossValidationData.row(i).target();
    System.out.println("predicted: " + predicted + "\texpected: " + actual);
}

System.out.println("Coefficients: " + glm.getCoefficients());

Sample code for logistic regression

The sample code below performs binary classification using logistic regression:

InputStream inputStream = new FileInputStream("heart_scale.txt");
DataFrame dataFrame = DataQuery.libsvm().from(inputStream).build();

for(int i = 0; i < dataFrame.rowCount(); ++i){
    DataRow row = dataFrame.row(i);
    String targetColumn = row.getTargetColumnNames().get(0);
    row.setTargetCell(targetColumn, row.getTargetCell(targetColumn) == -1 ? 0 : 1); // change output from (-1, +1) to (0, 1)
}

TupleTwo<DataFrame, DataFrame> miniFrames = dataFrame.shuffle().split(0.9);
DataFrame trainingData = miniFrames._1();
DataFrame crossValidationData = miniFrames._2();

Glm algorithm = Glm.logistic();
algorithm.setSolverType(GlmSolverType.GlmIrlsQr);
algorithm.fit(trainingData);

double threshold = 1.0;
for(int i = 0; i < trainingData.rowCount(); ++i){
    double prob = algorithm.transform(trainingData.row(i));
    if(trainingData.row(i).target() == 1 && prob < threshold){
        threshold = prob;
    }
}
System.out.println("threshold: " + threshold); // smallest predicted probability among positive training rows


BinaryClassifierEvaluator evaluator = new BinaryClassifierEvaluator();

for(int i = 0; i < crossValidationData.rowCount(); ++i){
    double prob = algorithm.transform(crossValidationData.row(i));
    boolean predicted = prob > 0.5;
    boolean actual = crossValidationData.row(i).target() > 0.5;
    evaluator.evaluate(actual, predicted);
    System.out.println("probability of positive: " + prob);
    System.out.println("predicted: " + predicted + "\tactual: " + actual);
}

evaluator.report();

Sample code for multi-class classification

The sample code below performs multi-class classification, using the logistic regression model as the base binary classifier in a one-vs-one scheme:

InputStream irisStream = FileUtils.getResource("iris.data");
DataFrame irisData = DataQuery.csv(",")
      .from(irisStream)
      .selectColumn(0).asNumeric().asInput("Sepal Length")
      .selectColumn(1).asNumeric().asInput("Sepal Width")
      .selectColumn(2).asNumeric().asInput("Petal Length")
      .selectColumn(3).asNumeric().asInput("Petal Width")
      .selectColumn(4).asCategory().asOutput("Iris Type")
      .build();

TupleTwo<DataFrame, DataFrame> parts = irisData.shuffle().split(0.9);

DataFrame trainingData = parts._1();
DataFrame crossValidationData = parts._2();

System.out.println(crossValidationData.head(10));

OneVsOneGlmClassifier multiClassClassifier = Glm.oneVsOne(Glm::logistic);
multiClassClassifier.fit(trainingData);

ClassifierEvaluator evaluator = new ClassifierEvaluator();

for(int i = 0; i < crossValidationData.rowCount(); ++i) {
    String predicted = multiClassClassifier.classify(crossValidationData.row(i));
    String actual = crossValidationData.row(i).categoricalTarget();
    System.out.println("predicted: " + predicted + "\tactual: " + actual);
    evaluator.evaluate(actual, predicted);
}

evaluator.report();

Background on GLM

Introduction

GLM (the generalized linear model) extends linear regression to response variables from the exponential family of distributions, using a model of the form b = g(a), where g is the inverse link function.

Therefore, for a regression characterized by the inverse link function g, the problem can be formulated as finding the model coefficient vector x in

$$g(Ax) = b + e$$

and the objective is to find the x that solves

$$\min_x \; (g(Ax) - b)^T W (g(Ax) - b)$$

Suppose we assume that e consists of uncorrelated noise variables with identical variance. Then W = \sigma^{-2} I, and the objective

$$\min_x \; (g(Ax) - b)^T W (g(Ax) - b)$$

reduces to the OLS form:

$$\min_x \; \lVert g(Ax) - b \rVert^2$$
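
As a concrete example, the logistic regression used in the samples above is the GLM for the Bernoulli family: its link function is the logit, so the inverse link g is the sigmoid

$$g(a) = \frac{1}{1 + e^{-a}}$$

and g(Ax) models the probability that the response equals 1.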

Iteratively Re-weighted Least Squares estimation (IRLS)

In regression, we try to find a set of model coefficients x such that

$$Ax = b + e$$

where A is known as the model (design) matrix, b as the response vector, and e as the vector of error terms.

In OLS (ordinary least squares) and its weighted variants, we assume that the variance-covariance matrix of the error term is

$$V(e) = \sigma^2 W^{-1}$$

where W is a symmetric positive definite matrix (here, a diagonal weight matrix) and \sigma is the standard error of e.
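
For example (a standard weighted least squares identity, not specific to this package), with independent errors of unequal variances the weight matrix is diagonal, with each weight inversely proportional to the corresponding error variance:

$$W = \mathrm{diag}(w_1, \ldots, w_n), \qquad w_i \propto \frac{1}{V(e_i)}$$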

The objective is then to find the estimate x_bar that minimizes the quadratic form e^T W e (note that since W is positive definite, this form is always positive). In other words, we are looking for the x_bar that minimizes (A x_bar - b)^T W (A x_bar - b).

Let

$$y = (Ax - b)^T W (Ax - b)$$

Differentiating y with respect to x, we have

$$\frac{dy}{dx} = 2 A^T W (Ax - b)$$

To minimize y, set dy/dx = 0 at x = \bar{x}:

$$A^T W (A\bar{x} - b) = 0$$

Rearranging gives the normal equations

$$A^T W A \bar{x} = A^T W b$$

Multiplying both sides by (A^T W A)^{-1}, we have

$$\bar{x} = (A^T W A)^{-1} A^T W b$$
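
In the unweighted special case W = I, this reduces to the familiar OLS solution:

$$\bar{x} = (A^T A)^{-1} A^T b$$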

In GLM this weighted least squares problem is commonly solved using IRLS, because the weights themselves depend on the current coefficient estimate.
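
As a sketch of the standard textbook IRLS iteration (not taken from this package's internals): at each step t, the weights W^(t) and a working response z^(t) are recomputed from the current estimate, and the weighted normal equations are re-solved:

$$x^{(t+1)} = (A^T W^{(t)} A)^{-1} A^T W^{(t)} z^{(t)}, \qquad z^{(t)} = A x^{(t)} + \frac{\partial \eta}{\partial \mu} \left( b - \mu^{(t)} \right)$$

where \eta = Ax is the linear predictor and \mu = g(\eta) is the fitted mean; the iteration repeats until the coefficients converge.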

The implementation of Glm in this package is based on iteratively re-weighted least squares estimation (IRLS).