Predicting Purchasing Habits with the Random Forest Algorithm

Random Forest is a supervised machine learning algorithm used to solve classification and regression problems. The algorithm is based on ensemble learning, which combines many classifiers to solve a complex problem and improve the performance of a model.

A Random Forest classifier builds a number of decision trees on different subsets of a dataset and combines their outputs to improve predictive accuracy. Rather than relying on a single decision tree, it takes the prediction from each tree and makes the final prediction based on the majority vote.
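The voting step described above can be sketched in a few lines of plain Java. This is only an illustration of the idea, not how Weka implements it internally; the class name MajorityVote and the tie-breaking rule (the smaller label wins) are assumptions made for the example.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper, not part of Weka or GridDB: combines per-tree predictions.
class MajorityVote {
    // Returns the class label that received the most votes.
    // Ties go to the smaller label so the result is deterministic.
    static int vote(int[] treePredictions) {
        if (treePredictions.length == 0) {
            throw new IllegalArgumentException("no predictions to combine");
        }
        // Count how many trees voted for each label
        Map<Integer, Integer> counts = new HashMap<>();
        for (int label : treePredictions) {
            counts.merge(label, 1, Integer::sum);
        }
        int best = treePredictions[0];
        for (Map.Entry<Integer, Integer> e : counts.entrySet()) {
            int count = e.getValue();
            int bestCount = counts.get(best);
            if (count > bestCount || (count == bestCount && e.getKey() < best)) {
                best = e.getKey();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Five trees: three predict "purchased" (1), two predict "not purchased" (0)
        int[] votes = {1, 0, 1, 1, 0};
        System.out.println(vote(votes)); // prints 1
    }
}
```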

A higher number of trees in the forest generally improves the accuracy of the algorithm and reduces the chance of overfitting. In this article, we will use the Random Forest algorithm and GridDB in Java to predict whether a user will buy an SUV based on their gender, age, and salary.

Store the Data in GridDB

The data to be used shows the gender, age, salary, and whether the user purchased the SUV or not. This data has been stored in a CSV file named “suv.csv”. We want to store the data in GridDB and enjoy its benefits including improved query performance.

Let’s first import the set of libraries to be used for this:

import java.io.File;
import java.util.Scanner;
import java.io.IOException;
import java.util.Properties;

// Note: java.util.Collection is not imported because its name would clash
// with the GridDB Collection type imported below.
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStoreFactory;

Let’s now create a static Java class to represent the GridDB container where we will store the data:

public static class SuvData {
    @RowKey int user_id;
    String gender;
    int age;
    int estimatedSalary;
    int purchased;
}

Each of the 5 columns of the GridDB container has been represented as a variable in the above class.

Let’s write the Java code that connects to GridDB. We will use the credentials of our GridDB installation:

        Properties props = new Properties();
        props.setProperty("notificationAddress", "239.0.0.1");
        props.setProperty("notificationPort", "31999");
        props.setProperty("clusterName", "defaultCluster");
        props.setProperty("user", "admin");
        props.setProperty("password", "admin");
        GridStore store = GridStoreFactory.getInstance().getGridStore(props);

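Now that we have established a connection to GridDB, let’s select the container that we will be working with. The container is named col01 and its rows are mapped by the SuvData class. Because the row key (user_id) is an int, the key type of the collection is Integer:

Collection<Integer, SuvData> coll = store.putCollection("col01", SuvData.class);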

We have created an instance of the container and given it the name coll. We will be using this name to refer to the container.

It’s now time to write data from the suv.csv file into the GridDB container:

        File file1 = new File("suv.csv");
        Scanner sc = new Scanner(file1);
        // The first line of the file is the header, so read and discard it
        String header = sc.nextLine();

        // Read whole lines so that spaces inside a field do not split a row
        while (sc.hasNextLine()) {
            String scData = sc.nextLine();
            String[] dataList = scData.split(",");
            String user_id = dataList[0];
            String gender = dataList[1];
            String age = dataList[2];
            String estimatedSalary = dataList[3];
            String purchased = dataList[4];

            SuvData sd = new SuvData();
            sd.user_id = Integer.parseInt(user_id);
            sd.gender = gender;
            sd.age = Integer.parseInt(age);
            sd.estimatedSalary = Integer.parseInt(estimatedSalary);
            sd.purchased = Integer.parseInt(purchased);

            // put() stores the row under its row key; append() belongs to TimeSeries containers
            coll.put(sd);
        }
        sc.close();

The code will add the data into the specified GridDB container.
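The loading loop assumes that every line of suv.csv holds exactly five comma-separated fields in the order user_id, gender, age, estimatedSalary, purchased. A minimal sketch of parsing a single line with a field-count check is shown below. The SuvCsvParser class and its Row holder are hypothetical names used only for illustration, and the sample values are made up.

```java
// Hypothetical helper: parses one line of suv.csv into typed fields.
class SuvCsvParser {
    // Plain holder that mirrors the columns of the SuvData container class
    static class Row {
        int userId;
        String gender;
        int age;
        int estimatedSalary;
        int purchased;
    }

    // Splits the line on commas and converts the numeric fields.
    // Throws IllegalArgumentException if the line does not have five fields
    // (NumberFormatException, a subclass, if a number cannot be parsed).
    static Row parse(String line) {
        String[] parts = line.trim().split(",");
        if (parts.length != 5) {
            throw new IllegalArgumentException(
                "expected 5 fields but found " + parts.length + ": " + line);
        }
        Row row = new Row();
        row.userId = Integer.parseInt(parts[0].trim());
        row.gender = parts[1].trim();
        row.age = Integer.parseInt(parts[2].trim());
        row.estimatedSalary = Integer.parseInt(parts[3].trim());
        row.purchased = Integer.parseInt(parts[4].trim());
        return row;
    }

    public static void main(String[] args) {
        // Illustrative values only, not taken from the real file
        Row row = parse("1,Male,19,19000,0");
        System.out.println(row.gender + " " + row.age + " " + row.estimatedSalary);
    }
}
```

Validating each line before it is written to the container makes a malformed row fail with a clear message instead of an ArrayIndexOutOfBoundsException halfway through the load.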

Retrieve the Data

We want to use the data to train a Random Forest machine learning model. So, we have to retrieve it. The following code can help us to retrieve the data from the GridDB container:

        Query<SuvData> query = coll.query("select *");
        // fetch(false) runs the query without locking the rows for update
        RowSet<SuvData> res = query.fetch(false);

We have used the select * statement to retrieve all the data stored in the GridDB container.

Implement a Random Forest Classifier

We can now use the data to create a Random Forest classifier. We will use the Random Forest classifier provided by Weka. Let’s begin by importing the libraries to help us achieve this:

import java.io.File;
import weka.core.Instances;
import weka.filters.Filter;
import java.io.IOException;
import java.io.FileReader;
import java.io.BufferedReader;
import weka.core.Instance;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.converters.ArffLoader;
import weka.classifiers.trees.RandomForest;
import weka.filters.unsupervised.attribute.StringToWordVector;

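Weka’s Instances class reads ARFF-formatted text from a Reader, and a FileReader cannot be created from a GridDB RowSet. Assuming the retrieved rows have been written out in ARFF format (the file name suv.arff below is a placeholder), let us read the data through a buffer and transform it into instances:

        // suv.arff is a placeholder name for the ARFF export of the retrieved rows
        BufferedReader bufferedReader = new BufferedReader(
                new FileReader("suv.arff"));

        // Create dataset instances
        Instances datasetInstances = new Instances(bufferedReader);
        bufferedReader.close();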
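Instances expects its Reader to supply ARFF-formatted text, so the rows fetched from GridDB need to be converted first. The following is a minimal sketch of that conversion in plain Java. The ArffWriter class is a hypothetical helper, the relation name suv is arbitrary, and the nominal values Male/Female and 0/1 are assumptions about the contents of suv.csv. The user ID is left out on the assumption that an identifier carries no predictive information.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: renders rows as ARFF text.
class ArffWriter {
    // Each row is {gender, age, estimatedSalary, purchased} as strings.
    static String toArff(List<String[]> rows) {
        StringBuilder sb = new StringBuilder();
        sb.append("@relation suv\n\n");
        // Nominal attributes list their allowed values in braces
        sb.append("@attribute gender {Male,Female}\n");
        sb.append("@attribute age numeric\n");
        sb.append("@attribute estimatedSalary numeric\n");
        // The class attribute is nominal so that Weka treats the task as classification
        sb.append("@attribute purchased {0,1}\n\n");
        sb.append("@data\n");
        for (String[] row : rows) {
            sb.append(String.join(",", row)).append('\n');
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // Illustrative values only
        List<String[]> rows = new ArrayList<>();
        rows.add(new String[] {"Male", "19", "19000", "0"});
        rows.add(new String[] {"Female", "47", "25000", "1"});
        System.out.print(toArff(rows));
    }
}
```

The resulting string could then be wrapped in a java.io.StringReader and passed to the Instances constructor, which avoids writing an intermediate file.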

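We can build a Random Forest classifier of 10 trees using the dataset. The last attribute (purchased) is set as the class to predict. Note that the model is trained on every instance, including the last one, which we will use later to make a prediction:

        // The last attribute is the class label
        datasetInstances.setClassIndex(datasetInstances.numAttributes() - 1);

        RandomForest forest = new RandomForest();
        // setNumTrees() matches the Weka 3.7 jar used in this article;
        // newer Weka releases may name this setter differently
        forest.setNumTrees(10);

        forest.buildClassifier(datasetInstances);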

Let’s now evaluate the model to get its summary statistics:

        // The model is evaluated on the same data it was trained on,
        // so the reported figures will be optimistic
        Evaluation eval = new Evaluation(datasetInstances);
        eval.evaluateModel(forest, datasetInstances);

        System.out.println("Random Forest Classifier Evaluation Summary");
        System.out.println(eval.toSummaryString());
        System.out.print(" The expression for the input data as per algorithm is: ");
        System.out.println(forest);
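Among other figures, the evaluation summary reports the share of correctly classified instances. To make it concrete how such a figure is derived, here is a small sketch in plain Java that builds a confusion matrix for a binary label and computes accuracy from it. The BinaryConfusion class is a hypothetical helper, not Weka code, and it assumes labels are encoded as 0 and 1.

```java
// Hypothetical helper: confusion matrix and accuracy for labels 0 and 1.
class BinaryConfusion {
    // Returns counts[actual][predicted]; both arrays must contain only 0 or 1.
    static int[][] confusion(int[] actual, int[] predicted) {
        if (actual.length != predicted.length) {
            throw new IllegalArgumentException("arrays must have the same length");
        }
        int[][] counts = new int[2][2];
        for (int i = 0; i < actual.length; i++) {
            counts[actual[i]][predicted[i]]++;
        }
        return counts;
    }

    // Accuracy is the number of correct predictions divided by all predictions.
    static double accuracy(int[][] counts) {
        int correct = counts[0][0] + counts[1][1];
        int total = correct + counts[0][1] + counts[1][0];
        return total == 0 ? 0.0 : (double) correct / total;
    }

    public static void main(String[] args) {
        int[] actual = {0, 0, 1, 1};
        int[] predicted = {0, 1, 1, 1};
        int[][] counts = confusion(actual, predicted);
        System.out.println(accuracy(counts)); // prints 0.75
    }
}
```

Because the figure depends on which instances are scored, measuring it on data the model has not seen during training gives a more realistic estimate than scoring the training set.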

Make a Prediction

We can now use the last instance of the dataset to make a prediction. The classifyInstance() function of the Weka library will help us know whether the user will buy the SUV or not. This is shown below:

        Instance pred = datasetInstances.lastInstance();
        double answer = forest.classifyInstance(pred);
        System.out.println(answer);
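When the class attribute is nominal, the double returned by classifyInstance() is the index of the predicted class value rather than the label itself. The sketch below mimics that lookup with a plain array so the printed number can be turned into a readable answer. PredictionLabel is a hypothetical helper, and the class values "0" and "1" are an assumption about how the purchased column is declared.

```java
// Hypothetical helper: maps a predicted class index to its label.
class PredictionLabel {
    // classValues lists the class labels in the order they are declared for the class attribute.
    static String describe(double classIndex, String[] classValues) {
        int index = (int) classIndex;
        if (index < 0 || index >= classValues.length) {
            throw new IllegalArgumentException("class index out of range: " + classIndex);
        }
        return classValues[index];
    }

    public static void main(String[] args) {
        String[] classValues = {"0", "1"}; // assumed order: not purchased, purchased
        String label = describe(1.0, classValues);
        System.out.println("1".equals(label) ? "will purchase" : "will not purchase");
    }
}
```

With a Weka dataset, the same lookup is available through datasetInstances.classAttribute().value((int) answer).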

Compile and Run the Model

We will use the Weka API to compile and run the above model. Download the API from the following URL:

http://www.java2s.com/Code/Jar/w/weka.htm

Next, log in as the gsadm user. Move your .java file to the bin folder of your GridDB installation, located in the following path:

/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin

Run the following command on your Linux terminal to set the path for the gridstore.jar file:

export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar

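Run the following command to compile your .java file. Because the -cp option overrides the CLASSPATH environment variable, the variable is included explicitly so that gridstore.jar stays on the class path:

javac -cp "$CLASSPATH:weka-3-7-0/weka.jar" RandomForestAlgorithm.java

Run the .class file that is generated by running the following command:

java -cp ".:$CLASSPATH:weka-3-7-0/weka.jar" RandomForestAlgorithm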

The prediction shows that the user will purchase the SUV.

One Comment

  1. jhon

    I would like a better guide to developing a Random Forest algorithm in Java using GridDB on my Windows PC.
