Linear Regression with Java

Introduction

Simple Linear Regression, often referred to simply as Linear Regression, is a regression algorithm that models the relationship between a dependent variable and a single independent variable. The model describes this relationship as a sloped straight line, hence the name.

In Linear Regression, the dependent variable must be a real or continuous value. However, the independent variable can be measured on a continuous or categorical scale.

The following are the two main objectives of the Simple Linear Regression algorithm:

• Model the relationship between two given variables, for example, job experience and salary, income and expenditure, etc.

• Predict new observations. For example, amount of revenue generated by a company based on investments in a year, weather forecasting based on temperature, etc.

Simple Linear Regression can be expressed using the following formula:

``Y = a + bX``

Where:

Y is the dependent variable.

a is the intercept of the regression line.

b is the slope of the line.

X is the independent variable.
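As a rough illustration of how `a` and `b` can be estimated from data, the sketch below applies the ordinary least squares formulas, b = Σ(x − x̄)(y − ȳ) / Σ(x − x̄)² and a = ȳ − b·x̄, to a handful of made-up points. The class name `SimpleOls` and the sample values are assumptions for illustration only and are not taken from the dataset used later.

```java
public class SimpleOls {

    /** Returns {a, b}: intercept and slope of the least-squares line Y = a + bX. */
    public static double[] fit(double[] x, double[] y) {
        int n = x.length;

        // Means of X and Y
        double meanX = 0, meanY = 0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        // Sums of cross-deviations and squared deviations
        double sxy = 0, sxx = 0;
        for (int i = 0; i < n; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }

        double b = sxy / sxx;          // slope
        double a = meanY - b * meanX;  // intercept
        return new double[] {a, b};
    }

    public static void main(String[] args) {
        // Made-up points that lie exactly on Y = 1 + 2X
        double[] x = {1, 2, 3, 4};
        double[] y = {3, 5, 7, 9};
        double[] ab = fit(x, y);
        System.out.println("a = " + ab[0] + ", b = " + ab[1]);
    }
}
```

Because the sample points lie exactly on a line, the fit recovers an intercept of 1 and a slope of 2. With real data the points scatter around the line, and the same formulas give the line that minimizes the squared vertical distances.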

Implementing Linear Regression in Java

In this section, we will be discussing how to implement the Simple Linear Regression algorithm in Java. The dataset to be used has two variables, that is, `salary`, which is the dependent variable, and `experience`, which is the independent variable.

We will build a Linear Regression model from this dataset. We will then use the model to predict the salary of an individual based on their experience.

Store the Data in GridDB

The data is stored in a CSV file named `Salary_Data.csv`, but we need to write it into a GridDB container. Although we could read the data directly from the CSV file, storing it in GridDB lets us query it and can improve query performance as the dataset grows.

The dataset has two columns, namely `yearsExperience` and `salary`.

Let’s start by importing the libraries that we will need for this:

``````import java.io.File;
import java.io.IOException;
import java.util.Properties;
import java.util.Scanner;

import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;``````

Next, let’s create a static Java class to represent the GridDB container to be used:

``````public static class ExperienceSalary {
    @RowKey float yearsExperience;
    Double salary;
}``````

The above class is equivalent to a SQL table with two columns. The two variables represent the columns of the GridDB container.

Let us now connect to GridDB using Java. We will use the credentials of our GridDB installation as shown below:

``````Properties props = new Properties();
props.setProperty("clusterName", "defaultCluster");
GridStore store = GridStoreFactory.getInstance().getGridStore(props);``````
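The snippet above sets only the cluster name. A multicast-based GridDB setup usually also needs a notification address and port plus a user name and password. The sketch below gathers these settings in one helper so they are easy to change. The class name `GridDbProps` is hypothetical, and the values passed in `main` (`239.0.0.1`, `31999`, `admin`) are placeholders, so replace them with the settings of your own installation.

```java
import java.util.Properties;

public class GridDbProps {

    /** Builds the connection properties expected by GridStoreFactory. */
    public static Properties build(String address, String port,
                                   String cluster, String user, String password) {
        Properties props = new Properties();
        props.setProperty("notificationAddress", address);
        props.setProperty("notificationPort", port);
        props.setProperty("clusterName", cluster);
        props.setProperty("user", user);
        props.setProperty("password", password);
        return props;
    }

    public static void main(String[] args) {
        // Placeholder values (assumptions): adjust to your own cluster
        Properties props = build("239.0.0.1", "31999", "defaultCluster", "admin", "admin");

        // The result would then be passed to
        // GridStoreFactory.getInstance().getGridStore(props)
        System.out.println("Connecting to cluster: " + props.getProperty("clusterName"));
    }
}
```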

Next, let's get the container named `col01`, whose rows are of the `ExperienceSalary` type. The call creates the container if it does not exist yet:

``Collection<String, ExperienceSalary> coll = store.putCollection("col01", ExperienceSalary.class);``

The container is now available through the variable `coll`, which we will use to refer to it.

Let us now read the data from the `Salary_Data.csv` file and write it into the GridDB container:

``````// Open the CSV file and skip the header row
File file1 = new File("Salary_Data.csv");
Scanner sc = new Scanner(file1);
String data = sc.next();

while (sc.hasNext()) {
    // Each token is one CSV row: yearsExperience,salary
    String scData = sc.next();
    String[] dataList = scData.split(",");
    String yearsExperience = dataList[0];
    String salary = dataList[1];

    ExperienceSalary es = new ExperienceSalary();
    es.yearsExperience = Float.parseFloat(yearsExperience);
    es.salary = Double.parseDouble(salary);

    // Write the row into the GridDB container
    coll.put(es);
}
sc.close();``````
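`Scanner.next()` splits on whitespace, so the loop above assumes that no row, including the header, contains spaces. A line-based variant is a little more forgiving. The following sketch parses CSV text into numeric pairs using only the standard library. The class name `CsvRows` and the sample values are illustrative assumptions, and the GridDB write is left out so that the example runs on its own.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

public class CsvRows {

    /** Parses "yearsExperience,salary" rows, skipping the header and blank lines. */
    public static List<double[]> parse(String csv) {
        List<double[]> rows = new ArrayList<>();
        try (Scanner sc = new Scanner(csv)) {
            if (sc.hasNextLine()) {
                sc.nextLine(); // skip the header row
            }
            while (sc.hasNextLine()) {
                String line = sc.nextLine().trim();
                if (line.isEmpty()) {
                    continue; // ignore blank lines
                }
                String[] parts = line.split(",");
                double experience = Double.parseDouble(parts[0].trim());
                double salary = Double.parseDouble(parts[1].trim());
                rows.add(new double[] {experience, salary});
            }
        }
        return rows;
    }

    public static void main(String[] args) {
        // Made-up sample content standing in for Salary_Data.csv
        String csv = "yearsExperience,salary\n1.0,40000\n2.5, 52000\n";
        for (double[] row : parse(csv)) {
            System.out.println(row[0] + " years -> " + row[1]);
        }
    }
}
```

Each parsed pair could then be copied into an `ExperienceSalary` object and written to the container as in the loop above.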

For each row, we create an object of the static class, named `es`, fill in its fields, and write it to the GridDB container.

Retrieve the Data from GridDB

We need to use the data to implement a Linear Regression model, so let’s pull it from the GridDB container:

``````Query<ExperienceSalary> query = coll.query("select *");
RowSet<ExperienceSalary> rs = query.fetch(false);``````

We have used the `select *` query to pull all the data stored in the GridDB container.

Build a Linear Regression Model

Now that we have the data, it’s time to implement a Linear Regression model that will tell us whether there is a correlation between the two variables of the dataset.

First, let’s import the libraries that we will need to use:

``````import java.io.BufferedReader;
import java.io.IOException;

import weka.classifiers.Evaluation;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;``````

Next, let’s create a buffered reader for the dataset. We will also create instances for the dataset:

``````// Create a buffered reader for the dataset
BufferedReader bufferedReader

// Create dataset instances
Instances datasetInstances``````
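The declarations above are shown without their initializers. One common pattern, offered here only as an assumption about what the reader wraps, is to hand Weka's `Instances(Reader)` constructor a reader over ARFF-formatted text. The sketch below builds such text from experience and salary pairs with plain Java. The class name `ArffBuilder`, the relation name `salary_data`, and the sample rows are hypothetical.

```java
public class ArffBuilder {

    /** Renders rows of {yearsExperience, salary} as ARFF text with two numeric attributes. */
    public static String toArff(double[][] rows) {
        StringBuilder sb = new StringBuilder();
        sb.append("@relation salary_data\n\n");
        sb.append("@attribute yearsExperience numeric\n");
        sb.append("@attribute salary numeric\n\n");
        sb.append("@data\n");
        for (double[] row : rows) {
            // One comma-separated line per row; salary comes last so it can be the class attribute
            sb.append(row[0]).append(',').append(row[1]).append('\n');
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // Made-up sample rows
        double[][] rows = {{1.0, 40000.0}, {2.5, 52000.0}};
        System.out.print(toArff(rows));
    }
}
```

Wrapping the returned string in `new BufferedReader(new StringReader(arff))` gives a reader that `new Instances(reader)` can parse, after which the class index can be set to the last attribute.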

Let us now use the dataset to implement a Linear Regression classifier:

``````datasetInstances.setClassIndex(datasetInstances.numAttributes()-1);

Classifier classifier = new weka.classifiers.functions.LinearRegression();

classifier.buildClassifier(datasetInstances);``````

Make a Prediction

Now that the model is ready, let us use it to make a prediction. We will be predicting the salary of the last instance in the dataset:

``````// Predicting the salary
Instance predSalary = datasetInstances.lastInstance();
double sal = classifier.classifyInstance(predSalary);
System.out.println("The salary is : "+sal);``````

Compile and Execute the Model

To compile and execute the model, you will need the Weka API. You can download it from the following URL:

http://www.java2s.com/Code/Jar/w/weka.htm

Next, log in as the `gsadm` user. Move your `.java` file to the `bin` folder of your GridDB installation, located in the following path:

/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin

Next, run the following command on your Linux terminal to set the path for the gridstore.jar file:

``export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar``

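Next, run the following command to compile your `.java` file. The `-cp` option overrides the `CLASSPATH` variable, so the command lists both the existing class path and the Weka jar:

``javac -cp "$CLASSPATH:weka-3-7-0/weka.jar" LinearRegression.java``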

Run the generated `.class` file with the following command, again including the existing class path so that `gridstore.jar` is found:

``java -cp ".:$CLASSPATH:weka-3-7-0/weka.jar" LinearRegression``

The model returned `122017.00`, which is close to the actual salary for that instance, `121872.00` (a difference of 145).
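To put the closeness of the prediction into numbers, the two figures quoted above can be turned into an absolute and a relative error. The class name `PredictionError` is only illustrative.

```java
public class PredictionError {

    /** Relative error of a prediction, as a fraction of the actual value. */
    public static double relativeError(double predicted, double actual) {
        return Math.abs(predicted - actual) / Math.abs(actual);
    }

    public static void main(String[] args) {
        double predicted = 122017.00; // value returned by the model
        double actual = 121872.00;    // salary recorded for the instance

        System.out.println("Absolute error: " + Math.abs(predicted - actual));
        System.out.println("Relative error (%): " + 100 * relativeError(predicted, actual));
    }
}
```

For these two values the prediction is off by 145, or roughly 0.12% of the actual salary. Since the predicted instance was also part of the data used to build the model, this figure describes the fit on known data rather than accuracy on unseen data.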