How to Implement a Random Forest Algorithm in Java

Random forest is a machine learning algorithm used for classification and regression. In this article, we describe the implementation of a random forest algorithm in Java to predict the class of iris plants. We begin by defining the requirements and importing the packages. Then, we present the Iris dataset, read it from a file, and store it in GridDB. Next, we retrieve the data from GridDB and run the random forest algorithm using the Weka library. Finally, we discuss the results.

Requirements

The code presented in the following sections uses GridDB to store and retrieve the dataset. For this reason, please download GridDB, create a node, and join a cluster. Do not forget to update the following environment variables for GridDB (here on Ubuntu 18.04):

export GS_HOME=$PWD
export GS_LOG=$PWD/log
export PATH=${PATH}:$GS_HOME/bin

We do the same for the gridstore.jar package:

export CLASSPATH=$CLASSPATH:/usr/share/java/gridstore.jar 

For the random forest algorithm, we download the Weka library and update our environment variable for Java to locate it:

export CLASSPATH=${CLASSPATH}:/usr/share/java/weka.jar

In the Java code, we connect to the GridDB cluster and obtain a GridDB store and a container. We also define the container schema as a collection with its columns. Here is the code to achieve this:

// Manage connection to GridDB
Properties properties = new Properties();
properties.setProperty("notificationAddress", "239.0.0.1");
properties.setProperty("notificationPort", "31999");
properties.setProperty("clusterName", "cluster");
properties.setProperty("database", "public");
properties.setProperty("user", "admin");
properties.setProperty("password", "admin");

// Get Store and Container
GridStore store = GridStoreFactory.getInstance().getGridStore(properties);

store.getContainer("newContainer");

String containerName = "mContainer";

// Define container schema and columns
ContainerInfo containerInfo = new ContainerInfo();
List<ColumnInfo> columnList = new ArrayList<ColumnInfo>();
columnList.add(new ColumnInfo("key", GSType.INTEGER));
columnList.add(new ColumnInfo("slenght", GSType.FLOAT));
columnList.add(new ColumnInfo("swidth", GSType.FLOAT));
columnList.add(new ColumnInfo("plenght", GSType.FLOAT));
columnList.add(new ColumnInfo("pwidth", GSType.FLOAT));
columnList.add(new ColumnInfo("irisclass", GSType.STRING));

containerInfo.setColumnInfoList(columnList);
containerInfo.setRowKeyAssigned(true);
Collection<Void, Row> collection = store.putCollection(containerName, containerInfo, false);
List<Row> rowList = new ArrayList<Row>();

In this application, we use classes from 4 main packages:

  • java.util: contains utility classes like ArrayList, List, Random, and Scanner
  • java.io: allows input/output operations to read the dataset from a file.
  • com.toshiba.mwcloud.gs: used for data interactions with GridDB.
  • weka.classifiers.trees: contains classes to implement a random forest algorithm.

Here is the code to import the packages. These classes will be used on various occasions in the following sections.

// ---------- Java Util ---------
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.Random;
import java.util.Scanner;
 
// ---------- Java IO ---------
import java.io.IOException;
import java.io.File;
import java.io.BufferedReader;
import java.io.FileReader;
 
// ---------- GridDB ---------
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.ColumnInfo;
import com.toshiba.mwcloud.gs.Container;
import com.toshiba.mwcloud.gs.ContainerInfo;
import com.toshiba.mwcloud.gs.GSType;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.Row;
import com.toshiba.mwcloud.gs.RowSet;
 
 
//----------- Weka ---------
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.classifiers.trees.RandomForest;
import weka.classifiers.Evaluation;

The dataset

The dataset used in this article is the Iris dataset, one of the default datasets shipped with Weka. It contains 150 instances collected from samples of the Iris plant. Each instance has 5 attributes: the sepal length, the sepal width, the petal length, the petal width, and the class of the plant (Iris Setosa, Iris Versicolour, Iris Virginica). The attributes correspond, in this order, to the columns of the dataset. Here is an extract:

5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica

Before building the random forest algorithm, we will begin by obtaining the dataset from the CSV file, and writing it into GridDB.

In the following code, we begin by opening the CSV file containing the dataset. We use a Scanner to iterate through the file and extract its content. Rows are separated by a newline (\n) and columns by a comma (,).

In each iteration of the while loop, we split the line into an array of strings, parse each value into its corresponding datatype, and place it in a dedicated variable.

// Handling Dataset and storage to GridDB
File data = new File("/home/ubuntu/griddb/gsSample/iris.csv");
Scanner sc = new Scanner(data);
sc.useDelimiter("\n");
while (sc.hasNext()) // true while another line is available
{
    Row row = collection.createRow();

    String line = sc.next();
    String columns[] = line.split(",");
    float slenght = Float.parseFloat(columns[0]);
    float swidth = Float.parseFloat(columns[1]);
    float plenght = Float.parseFloat(columns[2]);
    float pwidth = Float.parseFloat(columns[3]);
    String irisclass = columns[4];
}
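The parsing step above can also be sketched without GridDB, which makes it easier to try in isolation. The class and method names below (IrisCsvSketch, IrisRecord, parseLine, parseAll) are hypothetical and not part of the original program. The sketch additionally trims each line and skips blank ones, on the assumption that the file may end with an empty line or use Windows line endings.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, shown only to illustrate the parsing step.
final class IrisCsvSketch {

    /** One parsed line of the Iris CSV file (field names follow the article). */
    static final class IrisRecord {
        final float slenght;
        final float swidth;
        final float plenght;
        final float pwidth;
        final String irisclass;

        IrisRecord(float slenght, float swidth, float plenght, float pwidth, String irisclass) {
            this.slenght = slenght;
            this.swidth = swidth;
            this.plenght = plenght;
            this.pwidth = pwidth;
            this.irisclass = irisclass;
        }
    }

    /** Parses one comma-separated line into its four measurements and class label. */
    static IrisRecord parseLine(String line) {
        String[] columns = line.trim().split(",");
        if (columns.length != 5) {
            throw new IllegalArgumentException("Expected 5 columns but got: " + line);
        }
        return new IrisRecord(
                Float.parseFloat(columns[0].trim()),
                Float.parseFloat(columns[1].trim()),
                Float.parseFloat(columns[2].trim()),
                Float.parseFloat(columns[3].trim()),
                columns[4].trim());
    }

    /** Parses every non-blank line; blank lines are skipped instead of failing. */
    static List<IrisRecord> parseAll(Iterable<String> lines) {
        List<IrisRecord> records = new ArrayList<>();
        for (String line : lines) {
            if (!line.trim().isEmpty()) {
                records.add(parseLine(line));
            }
        }
        return records;
    }
}
```

Keeping the parsing in a small method like this means a malformed line produces a clear error message rather than a bare NumberFormatException from inside the loop.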

In later sections, we write the data obtained from the dataset into GridDB.

The random forest algorithm is applied to predict the Iris subspecies according to the sepal and petal dimensions of the Iris flower. We discuss its implementation in the next section.

Implementing a Random Forest Algorithm in Java

Using a random forest on the Iris dataset is intended to improve the accuracy of predicting the iris subspecies compared to a single decision tree. A random forest, as the name suggests, combines multiple decision trees to build a result. A single decision tree greedily picks the best split at each node using the whole training set. In a random forest, each tree is trained on a random sample of the training data and considers only a random subset of the attributes at each split; the forest then combines the predictions of all its trees, typically by majority vote for classification. We will observe that, on this dataset, this difference results in a slightly higher accuracy when predicting the Iris subspecies.

To implement the random forest algorithm, we use the Weka library. The RandomForest class is located in the weka.classifiers.trees package, which also contains other classification algorithms such as the J48 decision tree mentioned in other articles.
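To make the idea of combining several trees more concrete, here is a minimal sketch of two building blocks of a random forest: drawing a random sample of rows for one tree, and combining the predictions of many trees by majority vote. The class ForestIdeasSketch and its methods are hypothetical illustrations and are not taken from Weka, whose own implementation may combine trees differently (for example by averaging class probabilities).

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

// Hypothetical illustration of two random forest ingredients; not Weka code.
final class ForestIdeasSketch {

    /** Draws n row indices with replacement: the random sample one tree would be trained on. */
    static int[] bootstrapIndices(int n, Random random) {
        int[] indices = new int[n];
        for (int i = 0; i < n; i++) {
            indices[i] = random.nextInt(n);
        }
        return indices;
    }

    /**
     * Returns the label predicted by the largest number of trees.
     * On a tie, the label that first reached the highest count wins.
     */
    static String majorityVote(List<String> treePredictions) {
        Map<String, Integer> counts = new HashMap<>();
        String best = null;
        int bestCount = 0;
        for (String label : treePredictions) {
            int count = counts.merge(label, 1, Integer::sum);
            if (count > bestCount) {
                best = label;
                bestCount = count;
            }
        }
        return best;
    }
}
```

Because every tree sees a different sample, the individual trees make different mistakes, and the vote tends to cancel some of them out.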

Write Data into GridDB

As shown in previous sections, we defined a Row object and a List<Row> that hold the data fetched from the dataset before it is written into GridDB. We fill each row with the following code:

// i is the integer row key; it is assumed to be declared before the loop and incremented for each row
row.setInteger(0, i);
row.setFloat(1, slenght);
row.setFloat(2, swidth);
row.setFloat(3, plenght);
row.setFloat(4, pwidth);
row.setString(5, irisclass);

rowList.add(row);

Store the Data in GridDB

Storing the data in GridDB relies on mapping the columns of the dataset to the container schema defined for GridDB. Recalling the schema code below, we can observe a direct mapping of the columns: first a key column of type integer, followed by four float columns, one per measured attribute, and finally the iris class, which is a string. The row-setting code in the previous section fills each of these columns.

 columnList.add(new ColumnInfo("key", GSType.INTEGER));
 columnList.add(new ColumnInfo("slenght", GSType.FLOAT));
 columnList.add(new ColumnInfo("swidth", GSType.FLOAT));
 columnList.add(new ColumnInfo("plenght", GSType.FLOAT));
 columnList.add(new ColumnInfo("pwidth", GSType.FLOAT));
 columnList.add(new ColumnInfo("irisclass", GSType.STRING));

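Each row is first added to the in-memory list with the line below. Adding to the list does not by itself write to GridDB: the rows are stored once the list is passed to the container, for example with collection.put(rowList) after the loop.

rowList.add(row);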

Retrieve the Data from GridDB

To verify that the data was correctly stored in GridDB, we perform a query that retrieves all the records in the container. This is the code that performs this operation:

// Retrieving data from GridDB
Container container = store.getContainer(containerName);
if (container == null) {
    throw new Exception("Container not found.");
}
Query<Row> query = container.query("SELECT *");
RowSet<Row> rowset = query.fetch();

The data obtained is placed in the rowset variable, which we can iterate over to read or print the data, as follows:

// Print GridDB data
// Column 0 holds the integer key, so the attributes start at index 1,
// matching the schema and the setter calls used when writing the rows.
while (rowset.hasNext()) {
    Row row = rowset.next();
    float slenght = row.getFloat(1);
    float swidth = row.getFloat(2);
    float plenght = row.getFloat(3);
    float pwidth = row.getFloat(4);
    String irisclass = row.getString(5);
    System.out.println(" slenght=" + slenght + ", swidth=" + swidth + ", plenght=" + plenght + ", pwidth=" + pwidth + ", irisclass=" + irisclass);
}

Build the Random Forest

The results achieved in this article are obtained by executing the random forest classifier with the following parameters:

-P 100 -I 100 -num-slots 1 -K 0 -M 1.0 -V 0.001 -S 1

To use this algorithm in Java, we begin by creating an object of the class RandomForest.

 RandomForest randomForest = new RandomForest();

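Then, we specify the parameters of the random forest algorithm as an array of strings, where each option name is followed by its value. We achieve this with the following lines of code:

String[] parameters = new String[14];

parameters[0] = "-P";          // bag size, as a percentage of the training set
parameters[1] = "100";
parameters[2] = "-I";          // number of iterations, i.e. trees in the forest
parameters[3] = "100";
parameters[4] = "-num-slots";  // number of execution slots (threads)
parameters[5] = "1";
parameters[6] = "-K";          // number of attributes considered per split (0 lets Weka choose)
parameters[7] = "0";
parameters[8] = "-M";          // minimum number of instances per leaf
parameters[9] = "1.0";
parameters[10] = "-V";         // minimum variance proportion required for a split
parameters[11] = "0.001";
parameters[12] = "-S";         // seed for the random number generator
parameters[13] = "1";

randomForest.setOptions(parameters);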
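Filling the array index by index is easy to get wrong when an option is added or removed. As a minimal sketch, the same 14 entries can be derived from the option string shown earlier by splitting it on whitespace. The class OptionsSketch and the method toOptions are hypothetical names. A plain split does not handle quoted option values, which is acceptable here only because none of these options contains spaces; Weka also provides weka.core.Utils.splitOptions for this purpose.

```java
// Hypothetical helper: turns a simple option string into the array expected by setOptions.
final class OptionsSketch {

    /** Splits an option string on runs of whitespace; quoted values are not supported. */
    static String[] toOptions(String optionString) {
        return optionString.trim().split("\\s+");
    }

    public static void main(String[] args) {
        String[] parameters = toOptions("-P 100 -I 100 -num-slots 1 -K 0 -M 1.0 -V 0.001 -S 1");
        // Print each entry with its index to compare against the hand-written array.
        for (int i = 0; i < parameters.length; i++) {
            System.out.println("parameters[" + i + "] = " + parameters[i]);
        }
    }
}
```

The resulting array can be passed to randomForest.setOptions in place of the hand-filled one.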

For more details about random forest parameters in Weka, please see the official Javadoc of the RandomForest class.

After that, we can build the classifier with the training dataset. We then evaluate the model with the Evaluation class, using cross-validation, and print the resulting summary to the command line. Here, datasetInstances is the Weka Instances object holding the Iris data, and numFolds is the number of folds (10 in the run reported below). This is achieved with the following lines of code:

randomForest.setOptions(parameters);

// Train the forest on the full dataset
randomForest.buildClassifier(datasetInstances);

// Evaluate with numFolds-fold cross-validation, seeded so that runs are repeatable
Evaluation evaluation = new Evaluation(datasetInstances);
evaluation.crossValidateModel(randomForest, datasetInstances, numFolds, new Random(1));

// Print the summary; the second argument also prints complexity statistics
System.out.println(evaluation.toSummaryString("\nResults\n======\n", true));
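To give an intuition of what cross-validation does with the data, the sketch below shuffles the instance indices and deals them into folds; each fold is then used once for testing while the remaining folds are used for training. The class KFoldSketch is a hypothetical illustration and not Weka's implementation. In particular, Weka reports a stratified cross-validation, which also balances the classes across folds, whereas this simplified version does not.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical, non-stratified k-fold split; Weka's own splitting is more elaborate.
final class KFoldSketch {

    /** Shuffles the indices 0..n-1 and deals them round-robin into numFolds folds. */
    static List<List<Integer>> folds(int n, int numFolds, Random random) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, random);

        List<List<Integer>> folds = new ArrayList<>();
        for (int f = 0; f < numFolds; f++) {
            folds.add(new ArrayList<>());
        }
        for (int i = 0; i < n; i++) {
            folds.get(i % numFolds).add(indices.get(i));
        }
        return folds;
    }
}
```

With 150 instances and 10 folds, each fold holds 15 instances, so every model is trained on 135 instances and tested on the 15 it has not seen.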

Now that our code is ready, we proceed to compile and run our Java code.

Compile and Run the Code

In the command line, we navigate to the GridDB folder and execute our commands. Our Java code is located in the randomForest.java file under the folder gsSample. First, we need to compile the file, and then we can execute it. This is achieved with the following commands:

javac gsSample/randomForest.java

java gsSample/randomForest.java

Conclusion & Results

Once we compile and run the code, the results below are printed to the command line. You can also choose to print only the Summary section. The output includes other sections, such as “Run information”, which shows the context in which the algorithm was executed, and the confusion matrix, which gives a detailed count of classified instances per class.

The random forest algorithm predicts the Iris class with an accuracy of 95.3%, correctly classifying 143 of the 150 instances. In the previous article, using a decision tree, the accuracy was around 94%.

=== Run information ===

Scheme:       weka.classifiers.trees.RandomForest -P 100 -I 100 -num-slots 1 -K 0 -M 1.0 -V 0.001 -S 1
Relation:     iris
Instances:    150
Attributes:   5
              sepallength
              sepalwidth
              petallength
              petalwidth
              class
Test mode:    10-fold cross-validation

=== Classifier model (full training set) ===

RandomForest

Bagging with 100 iterations and base learner

weka.classifiers.trees.RandomTree -K 0 -M 1.0 -V 0.001 -S 1 -do-not-check-capabilities

Time taken to build model: 0.01 seconds

=== Stratified cross-validation ===
=== Summary ===

Correctly Classified Instances         143               95.3333 %
Incorrectly Classified Instances         7                4.6667 %
Kappa statistic                          0.93  
Mean absolute error                      0.0408
Root mean squared error                  0.1621
Relative absolute error                  9.19   %
Root relative squared error             34.3846 %
Total Number of Instances              150     

=== Detailed Accuracy By Class ===

                 TP Rate  FP Rate  Precision  Recall   F-Measure  MCC      ROC Area  PRC Area  Class
                 1,000    0,000    1,000      1,000    1,000      1,000    1,000     1,000     Iris-setosa
                 0,940    0,040    0,922      0,940    0,931      0,896    0,991     0,984     Iris-versicolor
                 0,920    0,030    0,939      0,920    0,929      0,895    0,991     0,982     Iris-virginica
Weighted Avg.    0,953    0,023    0,953      0,953    0,953      0,930    0,994     0,989     

=== Confusion Matrix ===

  a  b  c   <-- classified as
 50  0  0 |  a = Iris-setosa
  0 47  3 |  b = Iris-versicolor
  0  4 46 |  c = Iris-virginica
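The summary figures can be checked by hand from the confusion matrix above, where each row is the actual class and each column the predicted class. The sketch below computes accuracy, recall, and precision from such a matrix; the class ConfusionMatrixSketch is a hypothetical helper and not part of Weka.

```java
// Hypothetical helper that recomputes summary metrics from a confusion matrix.
// Convention: matrix[actual][predicted], as in the output above.
final class ConfusionMatrixSketch {

    /** Fraction of all instances that lie on the diagonal (correctly classified). */
    static double accuracy(int[][] matrix) {
        int correct = 0;
        int total = 0;
        for (int actual = 0; actual < matrix.length; actual++) {
            for (int predicted = 0; predicted < matrix[actual].length; predicted++) {
                total += matrix[actual][predicted];
                if (actual == predicted) {
                    correct += matrix[actual][predicted];
                }
            }
        }
        return (double) correct / total;
    }

    /** Recall (TP rate) of a class: correct predictions divided by the row total. */
    static double recall(int[][] matrix, int cls) {
        int rowSum = 0;
        for (int value : matrix[cls]) {
            rowSum += value;
        }
        return (double) matrix[cls][cls] / rowSum;
    }

    /** Precision of a class: correct predictions divided by the column total. */
    static double precision(int[][] matrix, int cls) {
        int columnSum = 0;
        for (int[] row : matrix) {
            columnSum += row[cls];
        }
        return (double) matrix[cls][cls] / columnSum;
    }
}
```

For the matrix above, the diagonal sums to 50 + 47 + 46 = 143, so the accuracy is 143/150, or about 95.33%. For Iris-versicolor, the recall is 47/50 = 0.94 and the precision is 47/51, or about 0.922, which matches the detailed accuracy table.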

If you have any questions about the blog, please create a Stack Overflow post here https://stackoverflow.com/questions/ask?tags=griddb .
Make sure that you use the “griddb” tag so our engineers can quickly reply to your questions.
