Introduction
Simple Linear Regression is the form of linear regression that models the relationship between a dependent variable and a single independent variable. The model describes that relationship as a sloped straight line, hence the name Simple Linear Regression.
In Linear Regression, the dependent variable must be a real or continuous value. The independent variable, however, can be measured on continuous or categorical values.
The following are the two main objectives of the Simple Linear Regression algorithm:
- Model the relationship between two given variables, for example, job experience and salary, or income and expenditure.
- Predict new observations, for example, the revenue a company generates based on its investments in a year, or a weather forecast based on temperature.
Simple Linear Regression can be expressed using the following formula:
Y = a + bX
Where:
Y is the dependent variable.
a is the intercept of the regression line.
b is the slope of the line.
X is the independent variable.
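To make the formula concrete, the following is a minimal sketch of how a and b can be estimated with ordinary least squares in plain Java. The class name SimpleRegressionSketch, its method names, and the sample values are made up for illustration; they are not taken from the salary dataset or from any library.

```java
public class SimpleRegressionSketch {

    // Returns {a, b} for the line Y = a + bX, fitted by ordinary least squares.
    public static double[] fit(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need at least two (x, y) pairs");
        }
        int n = x.length;

        // Means of X and Y
        double meanX = 0.0;
        double meanY = 0.0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        // Slope b = sum((x - meanX) * (y - meanY)) / sum((x - meanX)^2)
        double sxy = 0.0;
        double sxx = 0.0;
        for (int i = 0; i < n; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }
        if (sxx == 0.0) {
            throw new IllegalArgumentException("all x values are identical");
        }
        double b = sxy / sxx;

        // Intercept a follows from the fact that the line passes through the means
        double a = meanY - b * meanX;
        return new double[] {a, b};
    }

    // Applies the fitted line to a new X value.
    public static double predict(double a, double b, double x) {
        return a + b * x;
    }

    public static void main(String[] args) {
        // Hypothetical points that lie exactly on Y = 1 + 2X
        double[] x = {1, 2, 3, 4};
        double[] y = {3, 5, 7, 9};
        double[] ab = fit(x, y);
        System.out.println("a = " + ab[0] + ", b = " + ab[1]);
        System.out.println("prediction for X = 5: " + predict(ab[0], ab[1], 5));
    }
}
```

A library implementation may add steps such as regularization or attribute selection, so its coefficients will not necessarily match this sketch exactly.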
Implementing Linear Regression in Java
In this section, we will discuss how to implement the Simple Linear Regression algorithm in Java. The dataset has two variables: salary, which is the dependent variable, and experience, which is the independent variable.
We will build a Linear Regression model from this dataset. We will then use the model to predict the salary of an individual based on their experience.
Store the Data in GridDB
The data has been stored in a CSV file named Salary_Data.csv, but we need to write it into a GridDB container. Although we can still read the data directly from the CSV file, GridDB will improve the performance of our queries.
The dataset has two columns, namely yearsExperience and salary.
Let’s start by importing the libraries that we will need for this:
import java.io.IOException;
import java.util.Properties;
import java.util.Scanner;
import java.io.File;
// GridDB's Collection; java.util.Collection is not imported because the two names would clash
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;
Next, let’s create a static Java class to represent the GridDB container to be used:
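public static class ExperienceSalary {
    // No @RowKey here: GridDB row keys must be of type STRING, INTEGER, LONG, or TIMESTAMP,
    // and several people can share the same years of experience
    float yearsExperience;
    Double salary;
}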
The above class is equivalent to a SQL table with two columns. The two variables represent the columns of the GridDB container.
Let us now connect to GridDB using Java. We will use the credentials of our GridDB installation as shown below:
Properties props = new Properties();
props.setProperty("notificationAddress", "239.0.0.1");
props.setProperty("notificationPort", "31999");
props.setProperty("clusterName", "defaultCluster");
props.setProperty("user", "admin");
props.setProperty("password", "admin");
GridStore store = GridStoreFactory.getInstance().getGridStore(props);
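Hardcoding the connection settings is fine for a walkthrough, but they can also be kept outside the source code. The following is a minimal sketch that reads the same five keys from properties-formatted text using only the standard library. The class name ConnectionConfigSketch and the idea of a separate settings file are assumptions for illustration; they are not part of the GridDB API.

```java
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.Properties;

public class ConnectionConfigSketch {

    // The keys set by the connection code above
    private static final String[] REQUIRED = {
        "notificationAddress", "notificationPort", "clusterName", "user", "password"
    };

    // Loads key=value pairs and fails early if a required setting is missing or blank.
    public static Properties load(Reader source) {
        Properties props = new Properties();
        try {
            props.load(source);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        for (String key : REQUIRED) {
            String value = props.getProperty(key);
            if (value == null || value.trim().isEmpty()) {
                throw new IllegalArgumentException("missing connection setting: " + key);
            }
        }
        return props;
    }

    public static void main(String[] args) {
        // In practice the text would come from a file; a string keeps the sketch self-contained
        String text = "notificationAddress=239.0.0.1\n"
                + "notificationPort=31999\n"
                + "clusterName=defaultCluster\n"
                + "user=admin\n"
                + "password=admin\n";
        Properties props = load(new StringReader(text));
        System.out.println("cluster: " + props.getProperty("clusterName"));
    }
}
```

The returned Properties object can be passed to getGridStore in place of the one built by hand above.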
We will store rows of type ExperienceSalary in a GridDB container named col01, so let’s create it (or open it if it already exists):
Collection<String, ExperienceSalary> coll = store.putCollection("col01", ExperienceSalary.class);
We have created an instance of the container and given it the name coll. We will use this instance name to refer to the container.
Let us now read the data from the Salary_Data.csv file and write it into the GridDB container:
File file1 = new File("Salary_Data.csv");
Scanner sc = new Scanner(file1);
// Skip the header line
String data = sc.nextLine();
while (sc.hasNextLine()) {
    String scData = sc.nextLine().trim();
    if (scData.isEmpty()) continue; // ignore blank lines
    String[] dataList = scData.split(",");
    String yearsExperience = dataList[0];
    String salary = dataList[1];
    ExperienceSalary es = new ExperienceSalary();
    es.yearsExperience = Float.parseFloat(yearsExperience);
    es.salary = Double.parseDouble(salary);
    // put() writes the row to the container; append() is only available on TimeSeries containers
    coll.put(es);
}
sc.close();
For each line of the file, we created an object of the static class, named es, and wrote it to the GridDB container.
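The parsing step can also be separated from the GridDB write so that it is easy to check on its own. The following is a minimal sketch under that assumption; the class name SalaryCsvSketch and the sample rows are hypothetical and are not the contents of Salary_Data.csv.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class SalaryCsvSketch {

    // Parses "yearsExperience,salary" lines into {experience, salary} pairs.
    // The first line is treated as a header, and blank lines are skipped.
    public static List<double[]> parse(List<String> lines) {
        List<double[]> rows = new ArrayList<>();
        for (int i = 1; i < lines.size(); i++) {
            String line = lines.get(i).trim();
            if (line.isEmpty()) {
                continue;
            }
            String[] fields = line.split(",");
            if (fields.length != 2) {
                throw new IllegalArgumentException("line " + (i + 1) + ": expected 2 fields");
            }
            rows.add(new double[] {
                Double.parseDouble(fields[0].trim()),
                Double.parseDouble(fields[1].trim())
            });
        }
        return rows;
    }

    public static void main(String[] args) {
        // Made-up sample rows for illustration
        List<String> lines = Arrays.asList(
                "YearsExperience,Salary", "1.5, 40000.0", "", "3.0,60000.0");
        for (double[] row : parse(lines)) {
            System.out.println(row[0] + " years -> " + row[1]);
        }
    }
}
```

With a real file, the list of lines could come from Files.readAllLines(Paths.get("Salary_Data.csv")), and each parsed pair would then be copied into an ExperienceSalary object before it is written to the container.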
Retrieve the Data from GridDB
We need to use the data to implement a Linear Regression model, so let’s pull it from the GridDB container:
Query<ExperienceSalary> query = coll.query("select *");
RowSet<ExperienceSalary> res = query.fetch(false);
We have used the select * query to pull all the data stored in the GridDB container.
Build a Linear Regression Model
Now that we have the data, it’s time to implement a Linear Regression model that will tell us whether there is a correlation between the two variables of the dataset.
First, let’s import the libraries that we will need to use:
import java.io.IOException;
import weka.classifiers.Evaluation;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader;
import java.io.BufferedReader;
import java.io.FileReader;
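Weka’s Instances class can be built from a reader that supplies ARFF text. So we first write the fetched rows out in ARFF form, wrap that text in a buffered reader, and then create the dataset instances from it:
// Build ARFF text from the rows fetched from GridDB
StringBuilder arff = new StringBuilder("@relation salary\n"
        + "@attribute yearsExperience numeric\n@attribute salary numeric\n@data\n");
while (res.hasNext()) {
    ExperienceSalary row = res.next();
    arff.append(row.yearsExperience).append(',').append(row.salary).append('\n');
}
BufferedReader bufferedReader
        = new BufferedReader(new java.io.StringReader(arff.toString()));
// Create dataset instances
Instances datasetInstances = new Instances(bufferedReader);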
Let us now use the dataset to implement a Linear Regression classifier:
datasetInstances.setClassIndex(datasetInstances.numAttributes()-1);
Classifier classifier = new weka.classifiers.functions.LinearRegression();
classifier.buildClassifier(datasetInstances);
Make a Prediction
Now that the model is ready, let us use it to make a prediction. We will be predicting the salary of the last instance in the dataset:
// Predicting the salary
Instance predSalary = datasetInstances.lastInstance();
double sal = classifier.classifyInstance(predSalary);
System.out.println("The salary is : "+sal);
Compile and Execute the Model
To compile and execute the model, you will need the Weka API. You can download it from the following URL:
http://www.java2s.com/Code/Jar/w/weka.htm
Next, log in as the gsadm user. Move your .java file to the bin folder of your GridDB installation, located at the following path:
/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin
Next, run the following command on your Linux terminal to set the path for the gridstore.jar file:
export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar
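Next, run the following command to compile your .java file. The -cp option replaces the CLASSPATH environment variable rather than adding to it, so $CLASSPATH is included explicitly to keep gridstore.jar visible:
javac -cp "$CLASSPATH:weka-3-7-0/weka.jar" LinearRegression.java
Run the generated .class file with the following command:
java -cp ".:$CLASSPATH:weka-3-7-0/weka.jar" LinearRegression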
The model returned 122017.00, which is very close to the real salary for the instance, that is, 121872.00.
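To put a number on how close the prediction is, the following is a small sketch that computes the absolute and relative error for the two values reported above. The class and method names are hypothetical and are not part of Weka.

```java
public class PredictionErrorSketch {

    // Distance between the predicted and the actual value
    public static double absoluteError(double predicted, double actual) {
        return Math.abs(predicted - actual);
    }

    // Absolute error as a fraction of the actual value
    public static double relativeError(double predicted, double actual) {
        return absoluteError(predicted, actual) / Math.abs(actual);
    }

    public static void main(String[] args) {
        double predicted = 122017.00; // value returned by the model
        double actual = 121872.00;    // real salary for the instance
        System.out.printf("absolute error: %.2f%n", absoluteError(predicted, actual));
        System.out.printf("relative error: %.3f%%%n", 100 * relativeError(predicted, actual));
    }
}
```

For these two values, the absolute error is 145 and the relative error is roughly 0.12%.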
If you have any questions about the blog, please create a Stack Overflow post here https://stackoverflow.com/questions/ask?tags=griddb .
Make sure that you use the “griddb” tag so our engineers can quickly reply to your questions.