Introduction
K-nearest neighbors (KNN) is a supervised machine learning algorithm used to perform classification and regression tasks.
KNN predicts the class of a test point by measuring the distance between that point and the training points. The algorithm selects the K training points that are closest to the test point. It then estimates the probability of each class as the share of those K neighbors that belong to it. The class with the highest probability is chosen.
The KNN algorithm is referred to as a lazy learning algorithm because it does not learn a discriminative function from the training dataset; instead, it memorizes the training dataset.
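The steps described above (measure distances, keep the K closest points, pick the most common class) can be sketched in plain Java. This is a minimal illustration rather than the implementation used later in this post; the class name KnnVoteSketch, the method names, and the toy points are all assumptions made for the example.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

class KnnVoteSketch {
    // Euclidean distance between two feature vectors of equal length.
    static double distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // Predicts a label for `query` by majority vote among its k nearest training points.
    static String classify(double[][] points, String[] labels, double[] query, int k) {
        double[] dist = new double[points.length];
        Integer[] order = new Integer[points.length];
        for (int i = 0; i < points.length; i++) {
            dist[i] = distance(points[i], query);
            order[i] = i;
        }
        // Sort the training indices from nearest to farthest.
        Arrays.sort(order, (a, b) -> Double.compare(dist[a], dist[b]));

        // Count votes among the k nearest; the vote share plays the role of the class probability.
        Map<String, Integer> votes = new HashMap<>();
        String best = null;
        for (int n = 0; n < Math.min(k, order.length); n++) {
            String label = labels[order[n]];
            int count = votes.merge(label, 1, Integer::sum);
            if (best == null || count > votes.get(best)) {
                best = label;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Hypothetical toy data: two well-separated groups.
        double[][] points = {{1, 1}, {1, 2}, {2, 1}, {8, 8}, {9, 8}};
        String[] labels = {"A", "A", "A", "B", "B"};
        System.out.println(classify(points, labels, new double[] {2, 2}, 3)); // prints A
        System.out.println(classify(points, labels, new double[] {8, 9}, 3)); // prints B
    }
}
```

Nothing is "trained" here: the sketch keeps the whole training set and does all of its work at prediction time, which is what makes KNN a lazy learner.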
The Dataset
We will use a customer dataset to make a prediction for unknown data. The dataset to be used shows the age, income, and the product bought by a customer as shown below:
We will find the recommendations for a customer with an age of 58 years and an income of 51000.
We will first calculate the distances and then, depending on the value of k, pick the k nearest neighbors.
By default, the value of k is 1, but we can pass a different value of k during the creation of the instance. If k is 1, it will show 1 product, that is, 1 nearest neighbor, while if k is 2, it will show 2 products, that is, 2 nearest neighbors.
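To make the effect of k concrete, here is a small sketch that returns the products bought by the k customers nearest to a customer aged 58 with an income of 51000. The customer rows below are invented for illustration and are not the contents of the dataset; the class and method names are likewise assumptions.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class NearestCustomersSketch {
    // Returns the products of the k customers closest to (age, income), nearest first.
    static List<String> nearestProducts(double[][] customers, String[] products,
                                        double age, double income, int k) {
        double[] dist = new double[customers.length];
        Integer[] order = new Integer[customers.length];
        for (int i = 0; i < customers.length; i++) {
            // Column 0 is age, column 1 is income.
            dist[i] = Math.hypot(customers[i][0] - age, customers[i][1] - income);
            order[i] = i;
        }
        Arrays.sort(order, (a, b) -> Double.compare(dist[a], dist[b]));

        List<String> result = new ArrayList<>();
        for (int n = 0; n < Math.min(k, order.length); n++) {
            result.add(products[order[n]]);
        }
        return result;
    }

    public static void main(String[] args) {
        // Hypothetical rows: {age, income} and the product each customer bought.
        double[][] customers = {
            {25, 35000}, {40, 48000}, {55, 50000}, {60, 62000}, {33, 41000}
        };
        String[] products = {"Phone", "Laptop", "Tablet", "Camera", "Phone"};

        System.out.println(nearestProducts(customers, products, 58, 51000, 1)); // prints [Tablet]
        System.out.println(nearestProducts(customers, products, 58, 51000, 2)); // prints [Tablet, Laptop]
    }
}
```

Note that in this sketch the raw distance is driven almost entirely by income, because income values are in the tens of thousands while ages are not. Scaling the features to a common range before measuring distance is a common remedy.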
Implementing K-Nearest Neighbors Algorithm in Java
We now want to implement the KNN algorithm in Java using the above dataset. The dataset has been saved in a CSV file named customers.csv.
We will read the data from the CSV file and load it into GridDB. The data will then be pulled from GridDB for analysis with the algorithm.
Import Packages
Let’s first import the packages that we will need to use:
import java.io.IOException;
import java.util.Properties;
import java.util.Scanner;
import java.io.File;
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;
Write Data into GridDB
We want to move the data from the CSV file into a GridDB container. First, let’s create the container schema as a static class:
public static class Customers{
@RowKey int customer;
int age;
Double income;
String purchased_product;
}
The above class is similar to a container or a SQL table with four columns.
Let’s establish a connection to GridDB. We will create a Properties instance using the specifics of our GridDB installation. Use the following code:
Properties props = new Properties();
props.setProperty("notificationAddress", "239.0.0.1");
props.setProperty("notificationPort", "31999");
props.setProperty("clusterName", "defaultCluster");
props.setProperty("user", "admin");
props.setProperty("password", "admin");
GridStore store = GridStoreFactory.getInstance().getGridStore(props);
Change the above details to reflect the specifics of your GridDB installation.
Let us select the Customers container since we will be using it:
Collection<Integer, Customers> coll = store.putCollection("col01", Customers.class);
An instance of the container Customers has been created and given the name coll. The key type is Integer because the row key, customer, is an int. We will be using this instance to refer to the container.
Store the Data in GridDB
We can use the following Java code to read data from the customers.csv file and store it in GridDB:
File file1 = new File("customers.csv");
Scanner sc = new Scanner(file1);
// Skip the header line
if (sc.hasNextLine()) {
    sc.nextLine();
}
// Read line by line so that values containing spaces are not split apart
while (sc.hasNextLine()) {
    String scData = sc.nextLine().trim();
    if (scData.isEmpty()) {
        continue;
    }
    String[] dataList = scData.split(",");
    String customer = dataList[0];
    String age = dataList[1];
    String income = dataList[2];
    String purchased_product = dataList[3];
    Customers customers = new Customers();
    customers.customer = Integer.parseInt(customer);
    customers.age = Integer.parseInt(age);
    customers.income = Double.parseDouble(income);
    customers.purchased_product = purchased_product;
    // append() belongs to TimeSeries containers; a Collection stores rows with put()
    coll.put(customers);
}
sc.close();
For each row of the file, we create a customers object holding the data about that customer. The object is then stored in the GridDB container.
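The parsing step can be tried on its own, without a GridDB connection. The sketch below assumes the same column order as the code above (customer, age, income, purchased_product) and a header on the first line; the Row holder class, the parse method, and the sample lines are hypothetical names and data used only for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class CsvRowSketch {
    // Plain holder mirroring the four columns of the Customers schema.
    static class Row {
        int customer;
        int age;
        double income;
        String purchasedProduct;
    }

    // Converts CSV lines into rows, skipping the header, blank lines, and short lines.
    static List<Row> parse(List<String> lines) {
        List<Row> rows = new ArrayList<>();
        for (int i = 1; i < lines.size(); i++) { // index 0 is assumed to be the header
            String line = lines.get(i).trim();
            if (line.isEmpty()) {
                continue;
            }
            String[] fields = line.split(",");
            if (fields.length < 4) {
                continue; // not enough columns to fill a row
            }
            Row row = new Row();
            row.customer = Integer.parseInt(fields[0].trim());
            row.age = Integer.parseInt(fields[1].trim());
            row.income = Double.parseDouble(fields[2].trim());
            row.purchasedProduct = fields[3].trim();
            rows.add(row);
        }
        return rows;
    }

    public static void main(String[] args) {
        // Hypothetical file contents.
        List<String> lines = Arrays.asList(
            "customer,age,income,purchased_product",
            "1,25,35000,Phone",
            "",
            "2,40,48000,Gaming Laptop");
        for (Row row : parse(lines)) {
            System.out.println(row.customer + " " + row.age + " " + row.income + " " + row.purchasedProduct);
        }
    }
}
```

Working on whole lines keeps a value such as "Gaming Laptop" in one piece, whereas a whitespace-delimited token reader would split it in two.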
Retrieve the Data from GridDB
It is now time to pull the data from the GridDB container. Use the following code:
Query<Customers> query = coll.query("select *");
RowSet<Customers> rs = query.fetch(false);
RowSet res = query.fetch();
The select * statement queries all the data in the database container.
Build the Classifier
It’s now time to build a classifier using the KNN algorithm and the loaded data. Let’s import the libraries to be used for this:
import java.io.IOException;
import java.util.Enumeration;
import java.text.DecimalFormat;
import weka.classifiers.Classifier;
import weka.core.Instances;
import weka.classifiers.lazy.IBk;
import weka.classifiers.Evaluation;
import weka.core.Instance;
import weka.core.converters.ArffLoader;
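Let us now build the model and print out its statistics. Weka's classifiers and its Evaluation class work on a weka.core.Instances object rather than on a GridDB RowSet, so the code below assumes that res holds the retrieved rows as Instances, with the purchased product as the last attribute: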
res.setClassIndex(res.numAttributes() - 1);
Classifier cls = new IBk(1);
cls.buildClassifier(res);
System.out.println(cls);
Evaluation evaluation = new Evaluation(res);
evaluation.evaluateModel(cls, res);
System.out.println(evaluation.toSummaryString());
System.out.println(evaluation.toClassDetailsString());
System.out.println(evaluation.toMatrixString());
We specified the value of k when creating the IBk instance. The IBk constructor takes an integer argument. If you pass it a value of 1, it will use 1 nearest neighbor. If you pass 2, it will use 2 nearest neighbors. If you don't pass any argument and call the default constructor, it will use 1 nearest neighbor. In our case, we have passed a value of 1, so the prediction for the customer is based on 1 nearest neighbor.
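One thing to keep in mind when reading the statistics: the model above is evaluated on the same data it was built from. With k set to 1, every training point has itself as its nearest neighbor, so this kind of evaluation tends to look better than the model really is. The plain-Java sketch below illustrates the idea by comparing it with a leave-one-out check, where each point is excluded from its own neighbor search. It does not reproduce Weka's output; the data, class name, and method names are hypothetical.

```java
class TrainingSetEvaluationSketch {
    // Index of the point nearest to `query`, ignoring index `skip` (pass -1 to skip nothing).
    static int nearestIndex(double[][] points, double[] query, int skip) {
        int best = -1;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < points.length; i++) {
            if (i == skip) {
                continue;
            }
            double d = Math.hypot(points[i][0] - query[0], points[i][1] - query[1]);
            if (d < bestDist) {
                bestDist = d;
                best = i;
            }
        }
        return best;
    }

    // Fraction of points whose single nearest neighbor carries the same label.
    static double accuracy(double[][] points, String[] labels, boolean leaveOneOut) {
        int correct = 0;
        for (int i = 0; i < points.length; i++) {
            int nn = nearestIndex(points, points[i], leaveOneOut ? i : -1);
            if (labels[nn].equals(labels[i])) {
                correct++;
            }
        }
        return (double) correct / points.length;
    }

    public static void main(String[] args) {
        // Hypothetical rows: {age, income} and the product each customer bought.
        double[][] points = {
            {25, 35000}, {40, 48000}, {55, 50000}, {60, 62000}, {33, 41000}
        };
        String[] labels = {"Phone", "Laptop", "Tablet", "Camera", "Phone"};

        System.out.println(accuracy(points, labels, false)); // prints 1.0
        System.out.println(accuracy(points, labels, true));  // prints 0.4
    }
}
```

On this toy data, evaluating on the training set reports every point as correctly classified, while the leave-one-out check gets only two of the five right.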
Compile and Run the Code
First, log in as the gsadm user. Move your .java file to the bin folder of your GridDB installation, located at the following path:
/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin
Next, run the following command on your Linux terminal to set the path for the gridstore.jar file:
export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar
Next, run the following command to compile your .java file:
javac KNNeighbor.java
Run the generated .class file with the following command:
java KNNeighbor
The KNN model will return 1 nearest neighbor for the customer.
If you have any questions about the blog, please create a Stack Overflow post here https://stackoverflow.com/questions/ask?tags=griddb .
Make sure that you use the “griddb” tag so our engineers can quickly reply to your questions.