K-Nearest Neighbor Algorithm in Java

Introduction

K-nearest neighbors (KNN) is a supervised machine learning algorithm used to perform classification and regression tasks.

KNN predicts the class of a test point by measuring the distance between that point and every training point. The algorithm selects the K training points closest to the test point, then estimates the probability of each class as the fraction of those K neighbors that belong to it. The class with the highest probability is chosen.
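The voting rule described above can be sketched in plain Java as follows. This is a minimal illustration, not the Weka implementation used later in this article. It assumes numeric feature vectors, unscaled Euclidean distance, and hypothetical data. The class and method names (KnnClassifier, distance, classShares, predict) are our own.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class KnnClassifier {

    /** Euclidean distance between two feature vectors of equal length. */
    static double distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /** Fraction of the k nearest training points that belong to each class. */
    static Map<String, Double> classShares(double[][] x, String[] y, double[] query, int k) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < x.length; i++) {
            order.add(i);
        }
        // Sort training indices so that the closest points come first.
        order.sort(Comparator.comparingDouble(i -> distance(x[i], query)));
        int n = Math.min(k, x.length);
        Map<String, Integer> votes = new HashMap<>();
        for (int j = 0; j < n; j++) {
            votes.merge(y[order.get(j)], 1, Integer::sum); // one vote per neighbor
        }
        Map<String, Double> shares = new HashMap<>();
        for (Map.Entry<String, Integer> e : votes.entrySet()) {
            shares.put(e.getKey(), e.getValue() / (double) n);
        }
        return shares;
    }

    /** Predicts the class with the highest share among the k nearest neighbors. */
    static String predict(double[][] x, String[] y, double[] query, int k) {
        Map<String, Double> shares = classShares(x, y, query, k);
        String best = null;
        for (Map.Entry<String, Double> e : shares.entrySet()) {
            if (best == null || e.getValue() > shares.get(best)) {
                best = e.getKey();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Hypothetical 2-D training points, not the customer dataset.
        double[][] x = {{1, 1}, {1, 2}, {2, 1}, {8, 8}, {8, 9}};
        String[] y = {"A", "A", "A", "B", "B"};
        double[] query = {2, 2};
        System.out.println(classShares(x, y, query, 3)); // {A=1.0}
        System.out.println(predict(x, y, query, 3));     // A
    }
}
```

The "probability" here is simply the vote share among the K neighbors, so with K = 3 a class needs at least 2 votes to reach a majority.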

The KNN algorithm is called a lazy learning algorithm because it does not learn a discriminative function from the training dataset. Instead, it memorizes the training dataset and defers all computation until a prediction is requested.
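To make the "lazy" idea concrete, here is a hypothetical 1-nearest-neighbor learner whose fit method only keeps the data; distance work happens only when predict is called. The class name LazyOneNN and its distance counter are illustrative assumptions, not part of any library.

```java
/** A deliberately lazy 1-nearest-neighbor learner: fit() only memorizes the data. */
final class LazyOneNN {
    private double[][] x = new double[0][];
    private String[] y = new String[0];
    private int distanceCalls = 0; // counts distance computations, to show when work happens

    /** "Training" just keeps the data; no parameters or decision function are learned. */
    void fit(double[][] features, String[] labels) {
        x = features;
        y = labels;
    }

    /** All the real work happens here: a scan over every memorized training point. */
    String predict(double[] query) {
        String best = null;
        double bestSq = Double.POSITIVE_INFINITY;
        for (int i = 0; i < x.length; i++) {
            distanceCalls++;
            double sq = 0.0;
            for (int j = 0; j < query.length; j++) {
                double d = x[i][j] - query[j];
                sq += d * d;
            }
            if (sq < bestSq) { // squared distance gives the same ordering as distance
                bestSq = sq;
                best = y[i];
            }
        }
        return best;
    }

    int distanceCalls() {
        return distanceCalls;
    }

    public static void main(String[] args) {
        LazyOneNN knn = new LazyOneNN();
        knn.fit(new double[][] {{0, 0}, {5, 5}, {9, 9}}, new String[] {"low", "mid", "high"});
        System.out.println(knn.distanceCalls());              // 0: fitting did no distance work
        System.out.println(knn.predict(new double[] {6, 6})); // mid
        System.out.println(knn.distanceCalls());              // 3: one distance per stored point
    }
}
```

The trade-off is that training is essentially free, while every prediction costs one distance computation per stored training point.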

The Dataset

We will use a customer dataset to make predictions for new data. The dataset records the age, income, and product bought by each customer, as shown below:

We will find the recommendations for a customer with an age of 58 years and an income of 51000.

We will first calculate the distance between this customer and each customer in the dataset, and then, depending on the value of k, select the k nearest neighbors.
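As a rough illustration of the distance step, the sketch below computes the unscaled Euclidean distance from the customer aged 58 with an income of 51000 to three hypothetical customers (the rows of customers.csv are not reproduced here). The class name CustomerDistance and the method euclidean are assumptions made for this example.

```java
/** Unscaled Euclidean distance between the query customer and a few hypothetical customers. */
final class CustomerDistance {

    static double euclidean(double age1, double income1, double age2, double income2) {
        double dAge = age1 - age2;
        double dIncome = income1 - income2;
        return Math.sqrt(dAge * dAge + dIncome * dIncome);
    }

    public static void main(String[] args) {
        double queryAge = 58, queryIncome = 51000; // the customer from the article
        // Hypothetical rows (age, income); not taken from customers.csv.
        double[][] customers = {{25, 52000}, {60, 48000}, {57, 60000}};
        for (double[] c : customers) {
            System.out.printf("age=%.0f income=%.0f distance=%.2f%n",
                    c[0], c[1], euclidean(queryAge, queryIncome, c[0], c[1]));
        }
    }
}
```

With raw values, income differences in the thousands swamp age differences of a few decades, so the 25-year-old customer ends up nearest. By default, Weka's EuclideanDistance, which IBk uses for its neighbor search, normalizes each numeric attribute to its range, which reduces this effect; a hand-written KNN would need to scale the features itself.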

By default, the value of k is 1, but we can pass a different value of k when creating the classifier instance. If k is 1, the model shows 1 product, taken from the single nearest neighbor; if k is 2, it shows 2 products, taken from the 2 nearest neighbors.
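The effect of k can be sketched with a small helper that returns the products bought by the k customers nearest to a query. The customer records and product names below are hypothetical placeholders, and nearestProducts is our own helper, not a library method.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/** Lists the products bought by the k customers closest to a query (age, income). */
final class NearestProducts {

    /** Hypothetical customer record: age, income, and the product bought. */
    static final class Customer {
        final int age;
        final double income;
        final String product;

        Customer(int age, double income, String product) {
            this.age = age;
            this.income = income;
            this.product = product;
        }
    }

    static List<String> nearestProducts(List<Customer> customers, int age, double income, int k) {
        List<Customer> sorted = new ArrayList<>(customers);
        // Closest customers first, using unscaled Euclidean distance on (age, income).
        sorted.sort(Comparator.comparingDouble(c -> Math.hypot(c.age - age, c.income - income)));
        List<String> products = new ArrayList<>();
        for (int i = 0; i < Math.min(k, sorted.size()); i++) {
            products.add(sorted.get(i).product);
        }
        return products;
    }

    public static void main(String[] args) {
        // Hypothetical customers; product names are placeholders, not from customers.csv.
        List<Customer> customers = Arrays.asList(
                new Customer(55, 50000, "ProductA"),
                new Customer(60, 53000, "ProductB"),
                new Customer(30, 90000, "ProductC"));
        System.out.println(nearestProducts(customers, 58, 51000, 1)); // [ProductA]
        System.out.println(nearestProducts(customers, 58, 51000, 2)); // [ProductA, ProductB]
    }
}
```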

Implementing the K-Nearest Neighbors Algorithm in Java

We now want to implement the KNN algorithm in Java using the above dataset. The dataset has been saved in a CSV file named customers.csv.

We will read the data from the CSV file and load it into GridDB. The data will then be pulled from GridDB and analyzed with the algorithm.

Import Packages

Let’s first import the packages that we will need to use:

import java.io.File;
import java.io.IOException;
import java.util.Properties;
import java.util.Scanner;

// Do not also import java.util.Collection: it clashes with GridDB's Collection below.
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;

Write Data into GridDB

We want to move the data from the CSV file into a GridDB container. First, let’s create the container schema as a static class:

public static class Customers {
    @RowKey int customer;
    int age;
    Double income;
    String purchased_product;
}

The above class defines the container schema, which is similar to a SQL table with four columns. The customer field, annotated with @RowKey, is the row key.

Let’s establish a connection to GridDB. We will create a Properties instance using the specifics of our GridDB installation. Use the following code:

        Properties props = new Properties();
        props.setProperty("notificationAddress", "239.0.0.1");
        props.setProperty("notificationPort", "31999");
        props.setProperty("clusterName", "defaultCluster");
        props.setProperty("user", "admin");
        props.setProperty("password", "admin");
        GridStore store = GridStoreFactory.getInstance().getGridStore(props);

Change the above details to reflect the specifics of your GridDB installation.

Next, let us create (or open) a collection that uses the Customers schema:

Collection<Integer, Customers> coll = store.putCollection("col01", Customers.class); // key type matches the int row key

The putCollection method creates a collection named col01 with the Customers schema if it does not already exist, and returns a handle to it. We store that handle in coll and use it to refer to the container.

Store the Data in GridDB

We can use the following Java code to read data from the customers.csv file and store it in GridDB:

File file1 = new File("customers.csv");
try (Scanner sc = new Scanner(file1)) {
    // Skip the header row (column names).
    if (sc.hasNextLine()) {
        sc.nextLine();
    }

    while (sc.hasNextLine()) {
        String line = sc.nextLine().trim();
        if (line.isEmpty()) {
            continue; // ignore blank lines, e.g. a trailing newline
        }
        // Columns: customer, age, income, purchased_product
        String[] dataList = line.split(",");

        Customers customers = new Customers();
        customers.customer = Integer.parseInt(dataList[0].trim());
        customers.age = Integer.parseInt(dataList[1].trim());
        customers.income = Double.parseDouble(dataList[2].trim());
        customers.purchased_product = dataList[3].trim();

        // Write the row into the GridDB collection (keyed by customer).
        coll.put(customers);
    }
}

For each line of the file, we create a Customers object, fill in its fields, and write it to the GridDB container.

Retrieve the Data from GridDB

It is now time to pull the data from the GridDB container. Use the following code:

Query<Customers> query = coll.query("select *");
RowSet<Customers> rs = query.fetch(false);

The TQL statement select * returns all rows in the container, and fetch(false) runs the query without locking the rows for update.

Build the Classifier

It’s now time to build a classifier using the KNN algorithm and the loaded data. Let’s import the libraries to be used for this:

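import java.util.ArrayList;
import java.util.List;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.lazy.IBk;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;

Let us now convert the rows fetched from GridDB into a Weka dataset, then build the model and print out its statistics. Weka cannot read a GridDB RowSet directly, so each Customers row becomes a Weka instance with age and income as numeric attributes and purchased_product as the nominal class attribute:

List<Customers> rows = new ArrayList<>();
List<String> products = new ArrayList<>(); // distinct class labels
while (rs.hasNext()) {
    Customers c = rs.next();
    rows.add(c);
    if (!products.contains(c.purchased_product)) products.add(c.purchased_product);
}
ArrayList<Attribute> attributes = new ArrayList<>();
attributes.add(new Attribute("age"));
attributes.add(new Attribute("income"));
attributes.add(new Attribute("purchased_product", products)); // nominal class
Instances res = new Instances("customers", attributes, rows.size());
for (Customers c : rows) {
    res.add(new DenseInstance(1.0, new double[] {c.age, c.income, products.indexOf(c.purchased_product)}));
}
res.setClassIndex(res.numAttributes() - 1); // last attribute is the class
Classifier cls = new IBk(1);                // k = 1
cls.buildClassifier(res);
System.out.println(cls);
Evaluation evaluation = new Evaluation(res);
evaluation.evaluateModel(cls, res);         // evaluated on the training data itself
System.out.println(evaluation.toSummaryString());
System.out.println(evaluation.toClassDetailsString());
System.out.println(evaluation.toMatrixString());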

We specified the value of k when creating the IBk instance. The IBk constructor takes an integer argument k: if you pass 1, the classifier uses the single nearest neighbor; if you pass 2, it uses the 2 nearest neighbors. If you use the default constructor without an argument, k defaults to 1. In our case, we passed 1, so the model relies on the single nearest neighbor. Because evaluateModel is called on the same data the model was trained on, the printed statistics are optimistic: with k = 1, each training row is usually its own nearest neighbor.

Compile and Run the Code

First, log in as the gsadm user. Move your .java file to the bin folder of your GridDB installation, located at the following path:

/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin

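Next, run the following command in your Linux terminal to add the gridstore.jar file to the CLASSPATH:

export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar

The classifier code also uses Weka, so weka.jar must be on the CLASSPATH too. Replace /path/to/weka.jar with the location of the Weka JAR on your machine:

export CLASSPATH=$CLASSPATH:/path/to/weka.jar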

Next, run the following command to compile your .java file:

javac KNNeighbor.java

Run the .class file that is generated by running the following command:

java KNNeighbor

The program prints the trained IBk model (k = 1), followed by its evaluation summary, per-class details, and confusion matrix.
