How to Implement a Decision Tree Algorithm in Java

The combination of statistical methods, machine learning, and artificial intelligence has enabled the development of various data mining techniques that serve as support systems in many fields. These methods have become vital to a number of sectors and industries.

This article describes the implementation of a decision tree algorithm in Java. It begins by presenting the dataset we use, and then the implementation of the algorithm using the Weka J48 decision tree. After that, we describe the different interactions with GridDB, which is our main database: we write the original dataset into GridDB, store the data, and retrieve it later to build the decision tree.

Requirements

In this section we describe the requirements and configuration used in this article.

Weka 3.9: Download and place the weka.jar file in the /usr/share/java/ path.

GridDB 4.6: After installation, a GridDB cluster has to be active.

Make sure to add the Weka library path to CLASSPATH. We perform the same operation for GridDB. Here are the corresponding command lines:

$ export CLASSPATH=${CLASSPATH}:/usr/share/java/weka.jar
$ export CLASSPATH=$CLASSPATH:/usr/share/java/gridstore.jar

The dataset

For the purpose of this article we have chosen to use the Iris plant dataset. This dataset is open source and available here. The dataset is composed of 149 entries with 5 attributes, described as follows:

  1. Sepal length in cm. Measures the sepal length of an Iris plant sample
  2. Sepal width in cm. Measures the sepal width of an Iris plant sample
  3. Petal length in cm. Measures the petal length of an Iris plant sample
  4. Petal width in cm. Measures the petal width of an Iris plant sample
  5. Class: Indicates the Iris subspecies. The possible values are: Iris Setosa, Iris Versicolour, Iris Virginica.

Here is an extract of the dataset:

4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor

The dataset is provided as a .CSV file, which is read, parsed, and temporarily placed in variables for later storage in a GridDB collection. The following code performs these operations:

// Handling the dataset and its storage to GridDB
File data = new File("/home/ubuntu/griddb/gsSample/iris.csv");
Scanner sc = new Scanner(data);
sc.useDelimiter("\n");
while (sc.hasNext())  // returns true while there are lines left to read
{
    Row row = collection.createRow();

    // Parse one CSV line into the five attributes
    String line = sc.next();
    String columns[] = line.split(",");
    float slenght = Float.parseFloat(columns[0]);
    float swidth = Float.parseFloat(columns[1]);
    float plenght = Float.parseFloat(columns[2]);
    float pwidth = Float.parseFloat(columns[3]);
    String irisclass = columns[4];

    // The row is filled with these values and added to a row list
    // in the "Store the Data in GridDB" section below.
}

Do not forget to close your scanner when finished!

 sc.close();

The decision tree is implemented in order to predict the Iris subspecies according to the petal and sepal dimensions. In the following section, we describe the implementation of a decision tree in Java.

Implementing a Decision Tree Algorithm in Java

As mentioned in earlier sections, this article uses the J48 decision tree available in the Weka package. This class generates pruned or unpruned C4.5 decision trees. Let's have a closer look at the implementation. We begin by importing the required Weka packages:

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.classifiers.trees.J48;
import weka.classifiers.Evaluation;

The J48 decision tree algorithm has a series of options that can be fine-tuned to match the algorithm to the dataset. In our case we set the following two options:

-C: Pruning confidence. Sets the confidence threshold for pruning.

-M: Minimum number of instances. Sets the minimum number of instances per leaf.

In our case, we set the pruning confidence to 0.25 and the minimum number of instances to 30. We achieve this with the following code:

String[] options = new String[4];
options[0] = "-C";
options[1] = "0.25";
options[2] = "-M";
options[3] = "30";
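
These options take effect only once they are applied to a classifier instance. The classifier is referred to as mytree in the later sections; a minimal sketch of creating it and applying the options (the variable name mytree is assumed from the rest of the article) looks like this:

// Create the J48 classifier and apply the options defined above (-C 0.25 -M 30)
J48 mytree = new J48();
mytree.setOptions(options);

Note that setOptions declares a checked exception, so the enclosing method has to declare or handle it, just like the evaluation calls shown later.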

Our decision tree is ready to be used! Now we will prepare the data and use it to build the algorithm.

Import packages

We use several classes from the java.util package. ArrayList and List are used to organize our data. The Properties class is used to pass the cluster connection parameters to the GridDB store instance. Finally, the Random class is used to set the seed in the cross-validation phase of building the decision tree.

import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.Random;

Different packages are imported to connect and interact with GridDB.

import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.ColumnInfo;
import com.toshiba.mwcloud.gs.Container;
import com.toshiba.mwcloud.gs.ContainerInfo;
import com.toshiba.mwcloud.gs.GSType;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.Row;
import com.toshiba.mwcloud.gs.RowSet;

Finally, we import packages to interact with the file containing the original dataset.

import java.io.IOException;
import java.util.Scanner;
import java.io.File;
import java.io.BufferedReader;
import java.io.FileReader;

Write Data into GridDB

In order to write data into GridDB, we begin by setting the connection properties for our cluster. We also create the GridStore instance that will hold our container, referenced by its name containerName.

Properties prop = new Properties();
prop.setProperty("notificationAddress", "239.0.0.1");
prop.setProperty("notificationPort", "31999");
prop.setProperty("clusterName", "cluster");
prop.setProperty("database", "public");
prop.setProperty("user", "admin");
prop.setProperty("password", "admin");

GridStore store = GridStoreFactory.getInstance().getGridStore(prop);
store.getContainer("newContainer");

// Name of the collection that will hold the iris data
String containerName = "last";

Now that we have successfully created our store and container, we can define the container schema by setting the container information and the list of columns with their corresponding data types.

// Define container schema and columns
ContainerInfo containerInfo = new ContainerInfo();
List<ColumnInfo> columnList = new ArrayList<ColumnInfo>();
columnList.add(new ColumnInfo("key", GSType.INTEGER));
columnList.add(new ColumnInfo("slenght", GSType.FLOAT));
columnList.add(new ColumnInfo("swidth", GSType.FLOAT));
columnList.add(new ColumnInfo("plenght", GSType.FLOAT));
columnList.add(new ColumnInfo("pwidth", GSType.FLOAT));
columnList.add(new ColumnInfo("irisclass", GSType.STRING));

containerInfo.setColumnInfoList(columnList);
containerInfo.setRowKeyAssigned(true);

GridDB offers two types of containers. In our case, we choose the Collection type, which is used to manage general data. For this purpose, we create our collection and a list of rows to organize our data. We also instantiate the Row class, which will hold a data row in each iteration while reading the original CSV file.

Collection<Void, Row> collection = store.putCollection(containerName, containerInfo, false);
List<Row> rowList = new ArrayList<Row>();
Row row = collection.createRow();

Store the Data in GridDB

At this point, we take the variables that temporarily hold the values read from the dataset and insert them into each row. After that, we add each row to the row list defined in the previous section.

// "key" holds a unique row index for the row key column (its definition is not shown in this extract)
row.setInteger(0, key);
row.setFloat(1, slenght);
row.setFloat(2, swidth);
row.setFloat(3, plenght);
row.setFloat(4, pwidth);
row.setString(5, irisclass);
rowList.add(row);

To store the data into GridDB, we use the following line of code:

collection.put(rowList);

Retrieve the Data from GridDB

To retrieve the data we have just stored in GridDB, we perform a TQL query that, in our case, selects all the data in our container.

Query<Row> query = collection.query("SELECT *");
RowSet<Row> rs = query.fetch();
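
TQL also supports filtering. For instance, a hypothetical query that returns only the samples whose petal width is above 1.0 cm (the variable names filtered and filteredRows are illustrative) could look like this:

// Hypothetical TQL query with a filter on the petal width column
Query<Row> filtered = collection.query("SELECT * WHERE pwidth > 1.0");
RowSet<Row> filteredRows = filtered.fetch();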

If we would like to visualize the retrieved data:

while (rs.hasNext()) {
    Row row = rs.next();
    // Column 0 is the integer key; the measurements start at column 1
    float slenght = row.getFloat(1);
    float swidth = row.getFloat(2);
    float plenght = row.getFloat(3);
    float pwidth = row.getFloat(4);
    String irisclass = row.getString(5);
    System.out.println(" slenght=" + slenght + ", swidth=" + swidth + ", plenght=" + plenght + ", pwidth=" + pwidth + ", irisclass=" + irisclass);
}

Build the Decision Tree

Now that our data is stored in GridDB, we are ready to build the decision tree. The Weka method that builds the classification tree takes an Instances object as its parameter, so we need the following lines of code to make our data suitable for Weka.

BufferedReader bufferedReader = new BufferedReader(new FileReader(res));  // res points to the dataset file (ARFF format expected)
Instances datasetInstances = new Instances(bufferedReader);
// Weka needs to know which attribute is the class: here, the last column
datasetInstances.setClassIndex(datasetInstances.numAttributes() - 1);
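
The snippet above builds the Weka Instances by reading the dataset file again. As an alternative, a hypothetical sketch of building the same Instances object directly from the rows retrieved from GridDB is shown below; the attribute names and column order are assumed from the container schema defined earlier, it reuses the datasetInstances name so the rest of the article still applies, it assumes the row set rs has not already been consumed by the printing loop, and it additionally requires importing weka.core.Attribute, weka.core.DenseInstance, and java.util.Arrays.

// Describe the iris attributes for Weka (four numeric attributes plus a nominal class)
ArrayList<Attribute> attributes = new ArrayList<Attribute>();
attributes.add(new Attribute("slenght"));
attributes.add(new Attribute("swidth"));
attributes.add(new Attribute("plenght"));
attributes.add(new Attribute("pwidth"));
attributes.add(new Attribute("irisclass",
        Arrays.asList("Iris-setosa", "Iris-versicolor", "Iris-virginica")));

// Create an empty dataset and mark the class attribute (the last column)
Instances datasetInstances = new Instances("iris", attributes, 0);
datasetInstances.setClassIndex(datasetInstances.numAttributes() - 1);

// Fill the dataset from the GridDB row set
while (rs.hasNext()) {
    Row row = rs.next();
    double[] values = new double[5];
    values[0] = row.getFloat(1);  // slenght (column 0 of the container is the key)
    values[1] = row.getFloat(2);  // swidth
    values[2] = row.getFloat(3);  // plenght
    values[3] = row.getFloat(4);  // pwidth
    values[4] = datasetInstances.attribute(4).indexOfValue(row.getString(5));  // irisclass
    datasetInstances.add(new DenseInstance(1.0, values));
}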

At this point, we simply proceed to build the classifier.

mytree.buildClassifier(datasetInstances);

Last but not least, we generate an evaluation for the classifier tree we have just built.

Evaluation eval = new Evaluation(datasetInstances);
eval.crossValidateModel(mytree, datasetInstances, 10, new Random(1));

As the last line of code shows, we perform a cross-validation on our data instances. This process splits the dataset in different ways to obtain unbiased results, which is especially important since we are working with a limited dataset.

Finally, we print the evaluation summary:

    System.out.println(eval.toSummaryString("\n ****** J48 *****\n", true));
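
If a deeper look at the results is needed, Weka's Evaluation class can also report per-class statistics and the confusion matrix, and printing the J48 object itself outputs the learned tree. A minimal, optional sketch (these calls also declare checked exceptions, like crossValidateModel above):

// Print the learned decision tree (J48 overrides toString())
System.out.println(mytree);

// Per-class precision/recall and the confusion matrix from the cross-validation
System.out.println(eval.toClassDetailsString("\n=== Detailed Accuracy By Class ===\n"));
System.out.println(eval.toMatrixString("\n=== Confusion Matrix ===\n"));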

Compile and Run the Code

In our setup, a .java file containing the code described in this article is located in the default gsSample folder. Navigate to your GridDB folder and execute the following commands to compile and run the code:

~/griddb$ javac gsSample/Select.java
~/griddb$ java gsSample/Select.java

As we can observe, running the code does not require any additional cluster parameters on the command line, as these were hardcoded in our Java file.
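
If we preferred not to hardcode these values, a hypothetical alternative would be to read them from the command-line arguments instead, for example:

// Hypothetical alternative: read the connection settings from command-line arguments
// args[0] = notification address, args[1] = notification port, args[2] = cluster name,
// args[3] = user, args[4] = password
Properties prop = new Properties();
prop.setProperty("notificationAddress", args[0]);
prop.setProperty("notificationPort", args[1]);
prop.setProperty("clusterName", args[2]);
prop.setProperty("user", args[3]);
prop.setProperty("password", args[4]);

The cluster parameters would then be passed on the command line when running the program.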

Conclusion

Here is an extract of the evaluation summary printed in the previous section:

---Registering Weka Editors---
Trying to add database driver (JDBC): jdbc.idbDriver - Error, not in CLASSPATH?

 ****** J48 *****

Correctly Classified Instances         141               94.6309 %
Incorrectly Classified Instances         8                5.3691 %
Kappa statistic                          0.9195
K&B Relative Info Score              13340.283  %
K&B Information Score                  211.4385 bits      1.4191 bits/instance
Class complexity | order 0             236.1698 bits      1.585  bits/instance
Class complexity | scheme             2179.3992 bits     14.6268 bits/instance
Complexity improvement     (Sf)      -1943.2293 bits    -13.0418 bits/instance
Mean absolute error                      0.0578
Root mean squared error                  0.1831
Relative absolute error                 13.0031 %
Root relative squared error             38.8358 %
Total Number of Instances              149

As highlighted above, the decision tree reached an accuracy of 94.6% in classifying the iris plants into their subspecies. Please visit the official documentation for Weka J48 to learn more about the results and their interpretation.

Do not forget to close the TQL query, the container, and the GridDB store:

query.close();
collection.close();
store.close();

If you have any questions about the blog, please create a Stack Overflow post here https://stackoverflow.com/questions/ask?tags=griddb .
Make sure that you use the “griddb” tag so our engineers can quickly reply to your questions.
