Linear Regression with Java

Introduction

Simple Linear Regression is a regression algorithm that models the relationship between a dependent variable and a single independent variable. The model represents this relationship as a sloped straight line, which is why it is called linear regression.

In Simple Linear Regression, the dependent variable must be a real (continuous) value. The independent variable, however, can be either continuous or categorical.

The following are the two main objectives of the Simple Linear Regression algorithm:

  • Model the relationship between two given variables, for example, job experience and salary, income and expenditure, etc.

  • Predict new observations, for example, the revenue a company will generate based on its investments in a year, or a weather forecast based on temperature.

Simple Linear Regression can be expressed using the following formula:

Y = a + bX

Where:

Y is the dependent variable.

a is the intercept of the regression line.

b is the slope of the line.

X is the independent variable.
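The intercept a and slope b are usually estimated by ordinary least squares (OLS), which chooses the line that minimizes the sum of squared vertical distances between the line and the data points. The closed-form estimates are b = Σ(x - x̄)(y - ȳ) / Σ(x - x̄)² and a = ȳ - b·x̄. The following is a minimal Java sketch of those two formulas; the SimpleLinearRegression class and its methods are hypothetical names, and this is not Weka's implementation (Weka's LinearRegression also fits by least squares but adds options such as attribute selection and a small ridge penalty).

```java
import java.util.Arrays;

/** Minimal ordinary-least-squares fit of Y = a + bX (illustrative sketch). */
public class SimpleLinearRegression {

    /** Returns {a, b}: the intercept and slope that minimize the squared error. */
    public static double[] fit(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need at least two paired observations");
        }
        double meanX = Arrays.stream(x).average().getAsDouble();
        double meanY = Arrays.stream(y).average().getAsDouble();

        double sxy = 0.0; // sum of (x - meanX)(y - meanY)
        double sxx = 0.0; // sum of (x - meanX)^2
        for (int i = 0; i < x.length; i++) {
            double dx = x[i] - meanX;
            sxy += dx * (y[i] - meanY);
            sxx += dx * dx;
        }
        if (sxx == 0.0) {
            throw new IllegalArgumentException("x values must not all be equal");
        }
        double b = sxy / sxx;         // slope
        double a = meanY - b * meanX; // intercept: the line passes through (meanX, meanY)
        return new double[] {a, b};
    }

    /** Evaluates Y = a + bX for a fitted model {a, b}. */
    public static double predict(double[] model, double x) {
        return model[0] + model[1] * x;
    }

    public static void main(String[] args) {
        // Hypothetical toy data lying exactly on Y = 1 + 2X
        double[] x = {1, 2, 3, 4};
        double[] y = {3, 5, 7, 9};
        double[] model = fit(x, y);
        System.out.println("a = " + model[0] + ", b = " + model[1]); // prints a = 1.0, b = 2.0
        System.out.println("prediction at X = 5: " + predict(model, 5)); // prints 11.0
    }
}
```

Because the toy points lie exactly on a line, OLS recovers a = 1 and b = 2; with real data such as salaries, the fitted line passes between the points instead.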

Implementing Linear Regression in Java

In this section, we implement the Simple Linear Regression algorithm in Java. The dataset has two variables: salary, which is the dependent variable, and years of experience, which is the independent variable.

We will build a Linear Regression model from this dataset and then use it to predict an individual's salary from their years of experience.

Store the Data in GridDB

The data is stored in a CSV file named Salary_Data.csv, and we will write it into a GridDB container. We could read the data directly from the CSV file, but storing it in GridDB lets us retrieve it with queries instead of re-reading the file each time.

The dataset has two columns, namely yearsExperience and salary.

Let’s start by importing the libraries that we will need for this:

import java.io.File;
import java.io.IOException;
import java.util.Properties;
import java.util.Scanner;

// java.util.Collection is deliberately not imported: its simple name would
// clash with GridDB's com.toshiba.mwcloud.gs.Collection and break compilation
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;

Next, let’s create a static nested class that defines the schema of the GridDB container:

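// Nested inside the main class; each field becomes a column of the container.
// No @RowKey: GridDB row keys must be a string, integer, or timestamp type (not float),
// and a key on yearsExperience would also merge rows that share the same value.
public static class ExperienceSalary {
     float yearsExperience;  // independent variable
     double salary;          // dependent variable
}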

This class plays the role of a SQL table definition with two columns: each field maps to one column of the GridDB container.

Let us now connect to GridDB from Java, using the connection settings and credentials of our GridDB installation:

        Properties props = new Properties();
        props.setProperty("notificationAddress", "239.0.0.1");
        props.setProperty("notificationPort", "31999");
        props.setProperty("clusterName", "defaultCluster");
        props.setProperty("user", "admin");
        props.setProperty("password", "admin");
        GridStore store = GridStoreFactory.getInstance().getGridStore(props);

Next, we create the collection, or open it if it already exists. Its name is col01, and its schema comes from the ExperienceSalary class:

Collection<String, ExperienceSalary> coll = store.putCollection("col01", ExperienceSalary.class);

The variable coll now refers to the col01 container, and we will use it for all subsequent reads and writes.

Let us now read the data from the Salary_Data.csv file and write it into the GridDB container:

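File file1 = new File("Salary_Data.csv");
// try-with-resources closes the Scanner (and the file) when we are done
try (Scanner sc = new Scanner(file1)) {
    // Skip the header row that holds the column names
    if (sc.hasNextLine()) {
        sc.nextLine();
    }

    while (sc.hasNextLine()) {
        String line = sc.nextLine().trim();
        if (line.isEmpty()) {
            continue; // ignore blank lines, e.g. at the end of the file
        }
        String[] dataList = line.split(",");

        // Map the two CSV fields onto the container's columns
        ExperienceSalary es = new ExperienceSalary();
        es.yearsExperience = Float.parseFloat(dataList[0].trim());
        es.salary = Double.parseDouble(dataList[1].trim());

        // Rows are written to a Collection with put()
        coll.put(es);
    }
}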

For each line, we create an ExperienceSalary object named es, fill its fields with the parsed values, and write it to the GridDB container.
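The loop above assumes every line is well formed. If you would like to validate the CSV content before writing anything to GridDB, or test the parsing logic without a running cluster, you could move it into a separate method. The following is a minimal sketch under that assumption; the SalaryCsv class, its Columns holder, and the sample text in main are hypothetical and not part of the original tutorial.

```java
import java.util.ArrayList;
import java.util.List;

/** Parses two-column "yearsExperience,salary" CSV text into paired arrays (illustrative sketch). */
public class SalaryCsv {

    /** Parsed columns: x = years of experience, y = salary. */
    public static final class Columns {
        public final double[] x;
        public final double[] y;

        private Columns(double[] x, double[] y) {
            this.x = x;
            this.y = y;
        }
    }

    /** Skips the header line and blank lines; throws on rows without exactly two fields. */
    public static Columns parse(String csvText) {
        List<double[]> rows = new ArrayList<>();
        String[] lines = csvText.split("\\R"); // split on any line break
        for (int i = 1; i < lines.length; i++) { // line 0 is the header
            String line = lines[i].trim();
            if (line.isEmpty()) {
                continue;
            }
            String[] fields = line.split(",");
            if (fields.length != 2) {
                throw new IllegalArgumentException("line " + (i + 1) + ": expected 2 fields");
            }
            rows.add(new double[] {
                Double.parseDouble(fields[0].trim()),
                Double.parseDouble(fields[1].trim())
            });
        }
        double[] x = new double[rows.size()];
        double[] y = new double[rows.size()];
        for (int i = 0; i < rows.size(); i++) {
            x[i] = rows.get(i)[0];
            y[i] = rows.get(i)[1];
        }
        return new Columns(x, y);
    }

    public static void main(String[] args) {
        // Hypothetical sample text in the same two-column layout
        String csv = "yearsExperience,salary\n1.5,40000\n\n3.0,55000\n";
        Columns c = parse(csv);
        System.out.println(c.x.length + " rows; first salary = " + c.y[0]); // prints 2 rows; first salary = 40000.0
    }
}
```

Failing fast on a malformed row, with its line number, is usually easier to debug than discovering a half-loaded container later.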

Retrieve the Data from GridDB

We need the data to build the Linear Regression model, so let’s pull it from the GridDB container:

// Select every row stored in the col01 collection
Query<ExperienceSalary> query = coll.query("select *");
RowSet<ExperienceSalary> rs = query.fetch(false);

We have used the select * query to pull all the data stored in the GridDB container.

Build a Linear Regression Model

Now that we have the data, it’s time to build a Linear Regression model that describes how salary changes with years of experience.
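Whether the two variables are linearly related can be checked with the Pearson correlation coefficient r, which ranges from -1 to 1; values close to either end indicate a strong linear relationship. In simple linear regression, r² equals the R² of the fitted least-squares line. The following is a minimal pure-Java sketch; the Correlation class and the toy samples are hypothetical and are not part of the tutorial's code.

```java
/** Pearson correlation coefficient between two equally sized samples (illustrative sketch). */
public class Correlation {

    /** Returns r in [-1, 1]; values near +1 or -1 indicate a strong linear relationship. */
    public static double pearson(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need at least two paired observations");
        }
        double sumX = 0.0, sumY = 0.0;
        for (int i = 0; i < x.length; i++) {
            sumX += x[i];
            sumY += y[i];
        }
        double meanX = sumX / x.length;
        double meanY = sumY / y.length;

        // Co-variation and spread of each sample around its mean
        double sxy = 0.0, sxx = 0.0, syy = 0.0;
        for (int i = 0; i < x.length; i++) {
            double dx = x[i] - meanX;
            double dy = y[i] - meanY;
            sxy += dx * dy;
            sxx += dx * dx;
            syy += dy * dy;
        }
        if (sxx == 0.0 || syy == 0.0) {
            throw new IllegalArgumentException("correlation is undefined for a constant sample");
        }
        return sxy / Math.sqrt(sxx * syy);
    }

    public static void main(String[] args) {
        // Hypothetical toy samples, not the Salary_Data.csv values
        double[] x = {1, 2, 3, 4};
        double[] rising = {2, 4, 6, 8};
        double[] falling = {8, 6, 4, 2};
        System.out.println(pearson(x, rising));  // prints 1.0
        System.out.println(pearson(x, falling)); // prints -1.0
    }
}
```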

First, let’s import the libraries that we will need to use:

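import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;

import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

Next, let’s convert the rows fetched from GridDB into ARFF, the text format that Weka’s Instances(Reader) constructor reads. We then wrap that text in a buffered reader and create the dataset instances from it:

            // Instances(Reader) parses ARFF text, so turn the GridDB rows in rs into ARFF
            StringBuilder arff = new StringBuilder(
                    "@relation salary\n@attribute yearsExperience numeric\n@attribute salary numeric\n@data\n");
            while (rs.hasNext()) {
                ExperienceSalary row = rs.next();
                arff.append(row.yearsExperience).append(',').append(row.salary).append('\n');
            }
            BufferedReader bufferedReader = new BufferedReader(new StringReader(arff.toString()));

            // Create dataset instances
            Instances datasetInstances = new Instances(bufferedReader);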

Let us now train a Linear Regression model on the dataset. Weka exposes its learning algorithms, including regression models, through the Classifier type:

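// Treat the last attribute (salary) as the value to predict
datasetInstances.setClassIndex(datasetInstances.numAttributes() - 1);

// Weka's LinearRegression, used through the generic Classifier type
Classifier classifier = new weka.classifiers.functions.LinearRegression();
// Fit the regression model to the training instances
classifier.buildClassifier(datasetInstances);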

Make a Prediction

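Now that the model is ready, let us use it to predict the salary of the last instance in the dataset. Because that instance was also part of the training data, the result shows how well the line fits a known point rather than how well the model predicts unseen data: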

// Predict the salary of the last instance in the dataset
Instance predSalary = datasetInstances.lastInstance();
double sal = classifier.classifyInstance(predSalary);
System.out.println("The salary is : " + sal);
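A single prediction says little about the overall fit. To summarize it across the whole dataset, you could compare actual and predicted salaries using the root mean squared error (RMSE) and the coefficient of determination R². The following pure-Java sketch computes both from two arrays. The RegressionMetrics class is hypothetical, the input arrays are assumed to be collected by looping over datasetInstances and calling classifier.classifyInstance on each instance, and the sample numbers in main are made up. Weka’s weka.classifiers.Evaluation class can also report statistics such as the RMSE.

```java
/** Goodness-of-fit measures for regression predictions (illustrative sketch). */
public class RegressionMetrics {

    /** Root mean squared error: typical size of a prediction error, in the target's units. */
    public static double rmse(double[] actual, double[] predicted) {
        checkLengths(actual, predicted);
        double sumSq = 0.0;
        for (int i = 0; i < actual.length; i++) {
            double err = actual[i] - predicted[i];
            sumSq += err * err;
        }
        return Math.sqrt(sumSq / actual.length);
    }

    /** Coefficient of determination: 1 - SS_res / SS_tot. */
    public static double rSquared(double[] actual, double[] predicted) {
        checkLengths(actual, predicted);
        double mean = 0.0;
        for (double a : actual) {
            mean += a;
        }
        mean /= actual.length;
        double ssRes = 0.0; // squared errors of the predictions
        double ssTot = 0.0; // squared deviations from the mean
        for (int i = 0; i < actual.length; i++) {
            ssRes += (actual[i] - predicted[i]) * (actual[i] - predicted[i]);
            ssTot += (actual[i] - mean) * (actual[i] - mean);
        }
        if (ssTot == 0.0) {
            throw new IllegalArgumentException("R^2 is undefined when all actual values are equal");
        }
        return 1.0 - ssRes / ssTot;
    }

    private static void checkLengths(double[] actual, double[] predicted) {
        if (actual.length != predicted.length || actual.length == 0) {
            throw new IllegalArgumentException("arrays must be non-empty and equally sized");
        }
    }

    public static void main(String[] args) {
        // Hypothetical actual and predicted values
        double[] actual = {10, 20, 30, 40};
        double[] predicted = {12, 18, 30, 40};
        System.out.println("RMSE = " + rmse(actual, predicted));
        System.out.println("R^2  = " + rSquared(actual, predicted));
    }
}
```

RMSE is expressed in the same unit as the salary, which makes it easy to interpret, while R² close to 1 means the line explains most of the variation in the data.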

Compile and Execute the Model

To compile and execute the model, you will need the Weka API. This tutorial uses Weka 3.7.0, which you can download from the following URL:

http://www.java2s.com/Code/Jar/w/weka.htm

Next, log in as the gsadm user and move your .java file to the bin folder of your GridDB installation, located at the following path:

/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin

Next, run the following command in your Linux terminal to add the gridstore.jar file to the CLASSPATH:

export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar

Next, compile your .java file. Because the -cp option overrides the CLASSPATH variable, the command includes $CLASSPATH explicitly so that gridstore.jar stays on the class path:

javac -cp "$CLASSPATH:weka-3-7-0/weka.jar" LinearRegression.java

Finally, run the generated .class file:

java -cp "$CLASSPATH:.:weka-3-7-0/weka.jar" LinearRegression

The model returned 122017.00, which is very close to the real salary for that instance, 121872.00 (a difference of 145, or about 0.12%).
