Linear regression is a supervised machine learning technique that predicts the value of one variable from the value of another. The variable being predicted is called the dependent variable, while the variable used to make the prediction is called the independent variable. The technique uses one or more independent variables to estimate the coefficients of a linear equation. It fits the straight line that minimizes the sum of squared differences between the predicted and the actual output values.
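To make the idea of estimating the coefficients concrete, here is a minimal sketch of least-squares fitting for a single independent variable, written in plain Java. The class name `SimpleOls`, the method `fit`, and the toy values are assumptions made for illustration; they are not part of Weka or GridDB, and Weka's own implementation may differ in its details.

```java
// Minimal sketch of simple (one-variable) least-squares regression.
// SimpleOls and fit are illustrative names, not part of Weka or GridDB.
final class SimpleOls {

    /** Returns {slope, intercept} of the least-squares line through (x[i], y[i]). */
    static double[] fit(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need at least two paired points");
        }
        int n = x.length;

        // Means of the independent and dependent variables.
        double meanX = 0.0;
        double meanY = 0.0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        // Sum of cross-deviations (sxy) and of squared x-deviations (sxx).
        double sxy = 0.0;
        double sxx = 0.0;
        for (int i = 0; i < n; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }
        if (sxx == 0.0) {
            throw new IllegalArgumentException("all x values are identical");
        }

        double slope = sxy / sxx;
        double intercept = meanY - slope * meanX;
        return new double[] {slope, intercept};
    }

    public static void main(String[] args) {
        // Toy values (not the article's dataset) that lie on the line y = 1 + 2x.
        double[] x = {1, 2, 3, 4};
        double[] y = {3, 5, 7, 9};
        double[] line = fit(x, y);
        System.out.println("slope=" + line[0] + " intercept=" + line[1]);
    }
}
```

The slope is the ratio of the cross-deviation sum to the squared-deviation sum, and the intercept follows from requiring the line to pass through the point of means.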
In this article, we will be implementing a linear regression model that predicts the salary of an individual based on their years of experience using Java and GridDB.
Write the Data into GridDB
The data to be used shows the years of experience and salaries for different individuals.
We will store the data in GridDB as it offers many benefits such as fast query performance. Let us first import the Java libraries to help us accomplish this:
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;
import java.util.*;
GridDB organizes data into containers, and each container can be represented as a static class in Java. Let us create a static class in Java to represent the container where the data will be stored:
static class SalaryData {
    @RowKey int id;
    double yearsExperience;
    double salary;
}
This class defines the schema of a GridDB container named SalaryData. Think of it as an SQL table with three columns, where id is the row key.
To write the data into GridDB, we must first establish a connection to the database. This requires creating a Properties object from the java.util package and setting our GridDB connection details as key-value pairs:
Properties props = new Properties();
props.setProperty("notificationMember", "127.0.1.1:10001");
props.setProperty("clusterName", "myCluster");
props.setProperty("user", "admin");
props.setProperty("password", "admin");
GridStore store = GridStoreFactory.getInstance().getGridStore(props);
We will use the store variable to interact with the database.
Now that we are connected, we can store the data in GridDB. Let us first define the data rows:
SalaryData row1 = new SalaryData();
row1.id=1;
row1.yearsExperience=1.1;
row1.salary=39343.00;
SalaryData row2 = new SalaryData();
row2.id=2;
row2.yearsExperience=1.3;
row2.salary=46205.00;
SalaryData row3 = new SalaryData();
row3.id=3;
row3.yearsExperience=1.5;
row3.salary=37731.00;
SalaryData row4 = new SalaryData();
row4.id=4;
row4.yearsExperience=2.0;
row4.salary=43525.00;
SalaryData row5 = new SalaryData();
row5.id=5;
row5.yearsExperience=2.2;
row5.salary=39891.00;
SalaryData row6 = new SalaryData();
row6.id=6;
row6.yearsExperience=2.9;
row6.salary=56642.00;
SalaryData row7 = new SalaryData();
row7.id=7;
row7.yearsExperience=3.0;
row7.salary=60150.00;
SalaryData row8 = new SalaryData();
row8.id=8;
row8.yearsExperience=3.2;
row8.salary=54445.00;
SalaryData row9 = new SalaryData();
row9.id=9;
row9.yearsExperience=3.2;
row9.salary=64445.00;
SalaryData row10 = new SalaryData();
row10.id=10;
row10.yearsExperience=3.7;
row10.salary=57189.00;
SalaryData row11 = new SalaryData();
row11.id=11;
row11.yearsExperience=3.9;
row11.salary=63218.00;
SalaryData row12 = new SalaryData();
row12.id=12;
row12.yearsExperience=4.0;
row12.salary=55794.00;
SalaryData row13 = new SalaryData();
row13.id=13;
row13.yearsExperience=4.0;
row13.salary=56957.00;
SalaryData row14 = new SalaryData();
row14.id=14;
row14.yearsExperience=4.1;
row14.salary=57081.00;
SalaryData row15 = new SalaryData();
row15.id=15;
row15.yearsExperience=4.5;
row15.salary=61111.00;
SalaryData row16 = new SalaryData();
row16.id=16;
row16.yearsExperience=4.9;
row16.salary=67938.00;
SalaryData row17 = new SalaryData();
row17.id=17;
row17.yearsExperience=5.1;
row17.salary=66029.00;
SalaryData row18 = new SalaryData();
row18.id=18;
row18.yearsExperience=5.3;
row18.salary=83088.00;
SalaryData row19 = new SalaryData();
row19.id=19;
row19.yearsExperience=5.9;
row19.salary=81363.00;
SalaryData row20 = new SalaryData();
row20.id=20;
row20.yearsExperience=6.0;
row20.salary=93940.00;
SalaryData row21 = new SalaryData();
row21.id=21;
row21.yearsExperience=6.8;
row21.salary=91738.00;
SalaryData row22 = new SalaryData();
row22.id=22;
row22.yearsExperience=7.1;
row22.salary=98273.00;
SalaryData row23 = new SalaryData();
row23.id=23;
row23.yearsExperience=7.9;
row23.salary=101302.00;
SalaryData row24 = new SalaryData();
row24.id=24;
row24.yearsExperience=8.2;
row24.salary=113812.00;
SalaryData row25 = new SalaryData();
row25.id=25;
row25.yearsExperience=8.7;
row25.salary=109431.00;
Let's create the SalaryData container (or open it if it already exists) where the data is to be stored:
Collection<Integer, SalaryData> sd = store.putCollection("SalaryData", SalaryData.class);
We can now call the put() method to write each row into the container:
sd.put(row1);
sd.put(row2);
sd.put(row3);
sd.put(row4);
sd.put(row5);
sd.put(row6);
sd.put(row7);
sd.put(row8);
sd.put(row9);
sd.put(row10);
sd.put(row11);
sd.put(row12);
sd.put(row13);
sd.put(row14);
sd.put(row15);
sd.put(row16);
sd.put(row17);
sd.put(row18);
sd.put(row19);
sd.put(row20);
sd.put(row21);
sd.put(row22);
sd.put(row23);
sd.put(row24);
sd.put(row25);
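The row definitions and put() calls above are repetitive, so they could also be written as a loop over two parallel arrays. The sketch below shows that idea in plain Java. The class `SalaryRows` and the method `buildRows` are hypothetical names, and the nested `SalaryData` is a stand-in without the `@RowKey` annotation so that the example compiles without gridstore.jar. The GridDB call itself appears only as a comment, since it needs a live connection.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch: build the rows in a loop instead of writing each one by hand.
// SalaryRows and buildRows are illustrative names, not GridDB APIs.
final class SalaryRows {

    // Stand-in for the article's SalaryData class (no @RowKey annotation here,
    // so the sketch has no dependency on gridstore.jar).
    static class SalaryData {
        int id;
        double yearsExperience;
        double salary;
    }

    static List<SalaryData> buildRows(double[] years, double[] salaries) {
        if (years.length != salaries.length) {
            throw new IllegalArgumentException("arrays must have the same length");
        }
        List<SalaryData> rows = new ArrayList<>();
        for (int i = 0; i < years.length; i++) {
            SalaryData row = new SalaryData();
            row.id = i + 1;                 // ids start at 1, as in the article
            row.yearsExperience = years[i];
            row.salary = salaries[i];
            rows.add(row);
        }
        return rows;
    }

    public static void main(String[] args) {
        // First three rows of the article's dataset.
        double[] years = {1.1, 1.3, 1.5};
        double[] salaries = {39343.00, 46205.00, 37731.00};
        for (SalaryData row : buildRows(years, salaries)) {
            // With a live GridDB connection, this is where sd.put(row) would go.
            System.out.println(row.id + " " + row.yearsExperience + " " + row.salary);
        }
    }
}
```

Keeping the values in arrays also makes it easier to extend the dataset later without adding more row variables.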
Retrieve the Data
We now retrieve the data from GridDB so that we can use it to fit a machine learning model. We will write a TQL query that selects and returns all the data stored in the SalaryData container:
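Query<SalaryData> query = sd.query("select *");
RowSet<SalaryData> rs = query.fetch(false);
double x = 0;
double y = 0;
// Collect every row so that the whole dataset is kept, not only the last row
List<Double> xs = new ArrayList<Double>();
List<Double> ys = new ArrayList<Double>();
System.out.println("Training dataset:");
while (rs.hasNext()) {
    SalaryData sd1 = rs.next();
    x = sd1.yearsExperience;
    y = sd1.salary;
    System.out.println(x + " " + y);
    xs.add(x);
    ys.add(y);
}
// data[0] holds the yearsExperience values, data[1] holds the salary values
double[][] data = new double[2][xs.size()];
for (int i = 0; i < xs.size(); i++) {
    data[0][i] = xs.get(i);
    data[1][i] = ys.get(i);
}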
select * is a TQL query that fetches all the data stored in the SalaryData container. We have also defined two variables, x and y, to hold the values of yearsExperience and salary, respectively, for each row fetched from the database. The values are then stored in a two-dimensional array named data. The while loop iterates over all the rows of the GridDB container.
Create Weka Instances
We now want to use the Weka machine learning library to fit a linear regression model. We have to convert our data into Weka instances. We will first create attributes for the dataset and store them in a FastVector data structure. Next, we will create Weka instances from the dataset.
Let’s first create the data structures for storing the attributes and the instances:
int numInstances = data[0].length;
FastVector atts = new FastVector();
List<Instance> instances = new ArrayList<Instance>();
We can now create a for loop and use it to iterate over the data items while populating the FastVector data structure with the attributes:
for (int dim = 0; dim < 2; dim++)
{
    Attribute current = new Attribute("Attribute" + dim, dim);
    if (dim == 0)
    {
        for (int obj = 0; obj < numInstances; obj++)
        {
            // Each instance has two attributes: yearsExperience and salary
            instances.add(new SparseInstance(2));
        }
    }
    for (int obj = 0; obj < numInstances; obj++)
    {
        instances.get(obj).setValue(current, data[dim][obj]);
    }
    atts.addElement(current);
}
Instances newDataset = new Instances("Dataset", atts, instances.size());
The loop variable dim iterates over the two attributes (years of experience and salary). For each attribute, the loop sets the corresponding value on every instance and adds the attribute to the FastVector. Finally, we create an empty Instances object named newDataset that will hold the instances used for building the model.
Build a Linear Regression Model
The data instances are now ready, so we can use them to build a machine learning model. Let's first import the necessary classes from Weka (these import statements go at the top of the source file):
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SparseInstance;
Let us set the class attribute for the dataset:
newDataset.setClassIndex(1);
The above line of code sets salary (the attribute at index 1) as the class attribute for the dataset.
Next, we add the instances to the dataset and use the LinearRegression class of the Weka library to build a linear regression model:
for (Instance inst : instances) {
    newDataset.add(inst);
}
Classifier classifier = new weka.classifiers.functions.LinearRegression();
classifier.buildClassifier(newDataset);
We have created an instance of the LinearRegression class and given it the name classifier. We have then called the buildClassifier() method to train the model on our dataset.
Make a Prediction
Let's use our linear regression model to predict the salary of a person based on their years of experience. We will use the last instance of our dataset to make the prediction:
Instance pd = newDataset.lastInstance();
double value = classifier.classifyInstance(pd);
System.out.println(value);
pd is the last instance of our dataset. The classifyInstance() method predicts the value of salary for that instance.
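For a model with a single independent variable, a prediction amounts to evaluating a straight line at the given number of years. The sketch below illustrates this in plain Java. The class `LinePrediction`, the method `predict`, and the intercept and slope values are all hypothetical and chosen only for illustration; they are not the coefficients learned by the Weka model in this article.

```java
// Sketch of a one-attribute linear prediction: salary = intercept + slope * years.
// LinePrediction and predict are illustrative names, not Weka APIs.
final class LinePrediction {

    static double predict(double intercept, double slope, double yearsExperience) {
        return intercept + slope * yearsExperience;
    }

    public static void main(String[] args) {
        // Hypothetical coefficients, NOT taken from the fitted Weka model.
        double intercept = 25000.0;
        double slope = 9500.0;

        // Values of the last row in the article's dataset.
        double years = 8.7;
        double actualSalary = 109431.00;

        double predicted = predict(intercept, slope, years);
        System.out.println("predicted=" + predicted);
        System.out.println("absolute error=" + Math.abs(predicted - actualSalary));
    }
}
```

Comparing the predicted value with the recorded salary, as in the last line, gives a quick sense of how far the line is from a known data point.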
Execute the Model
To run the model, first download the Weka API from the following URL:
http://www.java2s.com/Code/Jar/w/weka.htm
Choose Weka version 3.7.0.
Set the class paths for the gridstore.jar and weka-3.7.0.jar files by running the following commands on the terminal:
export CLASSPATH=$CLASSPATH:/usr/share/java/gridstore.jar
export CLASSPATH=$CLASSPATH:/mnt/c/Users/user/Desktop/weka-3.7.0.jar
The above commands may change depending on the location of your files.
Compile your .java file by running the following command:
javac Salary.java
Run the generated .class file using the following command:
java Salary
In the original run, the model predicted a salary of 109431.0 for this person. The exact value you see depends on the rows the model was trained on.
If you have any questions about the blog, please create a Stack Overflow post here https://stackoverflow.com/questions/ask?tags=griddb .
Make sure that you use the “griddb” tag so our engineers can quickly reply to your questions.