JavaによるK-Nearest Neighborアルゴリズム

はじめに

K-nearest neighbors (KNN) は分類や回帰のタスクを実行するために用いられる教師あり機械学習アルゴリズムです。

KNNは、テストデータと学習点の距離を求めることで、テストデータに対して正しいクラスを予測します。このアルゴリズムでは、テストデータに最も近い点をK個選択します。そして、テストデータがK個のクラスに分類される確率を計算します。そして、最も高い確率を持つクラスが選択されます。

KNN アルゴリズムは、学習データから識別関数を学習するのではなく、学習データセットを記憶するため、遅延学習アルゴリズム と呼ばれます。

データセット

未知のデータに対して予測を行うために、顧客データセットを使用します。使用するデータセットは、以下のように、顧客の年齢、収入、購入した商品などを示しています。

年齢58歳、収入51000円のお客様に対するオススメ商品を見つけようと思います。

まず距離を計算し、次に k の値に応じて最も近い k 人の近傍を取得することができるようにします。

デフォルトでは k の値は 1 ですが、インスタンス生成時に k の値を渡すことができます。k が 1 の場合は 1 個の製品、つまり 1 件の最近傍が表示され、2 の場合は 2 個の製品、つまり 2 件の最近傍が表示されます。

K-最近傍アルゴリズムの Java での実装

上記のデータセットを用いて、JavaでKNNアルゴリズムを実装してみます。データセットは customers.csv という名前のCSVファイルに保存されています。

csvファイルからデータを読み込み、GridDBにロードします。その後、GridDBからデータを取り出し、アルゴリズムを使って分析します。

パッケージのインポート

まず、使用する必要のあるパッケージをインポートしましょう。

import java.io.IOException;
import java.util.Collection;
import java.util.Properties;
import java.util.Scanner;
import java.io.File;

import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;

データをGridDBに書き込む

CSVファイルからGridDBコンテナにデータを移動させたいと思います。まず、コンテナ・スキーマを静的クラスとして作成しましょう。


public static class Customers{

        @RowKey int customer;
        int age;
        Double income;
        String purchased_product;
}

上記のクラスは、コンテナや4つのカラムを持つSQLテーブルに似ています。

それでは、GridDBへの接続を確立してみましょう。GridDBをインストールしたときの指定情報に基づいてPropertiesのインスタンスを作成します。以下のコードを使用します。


        Properties props = new Properties();
        props.setProperty("notificationAddress", "239.0.0.1");
        props.setProperty("notificationPort", "31999");
        props.setProperty("clusterName", "defaultCluster");
        props.setProperty("user", "admin");
        props.setProperty("password", "admin");
        GridStore store = GridStoreFactory.getInstance().getGridStore(props);

GridDB のインストール環境に合わせて、上記の内容を変更します。

ここでは、Customersコンテナを使用するため、これを選択します。

Collection<String, Customers> coll = store.putCollection("col01", Customers.class);

コンテナ Customers のインスタンスが作成され、coll という名前が付けられました。このインスタンスを使って、コンテナを参照することにします。

データをGridDBに格納する

以下のJavaコードで、customers.csvファイルからデータを読み込んで、GridDBに格納することができます。


                File file1 = new File("customers.csv");
                Scanner sc = new Scanner(file1);
                String data = sc.next();
 
                while (sc.hasNext()){
                        String scData = sc.next();
                        String dataList[] = scData.split(",");
                        String customer = dataList[0];
                        String age = dataList[1];
                        String income = dataList[2];
                        String purchased_product = dataList[3];
                        
                        Customers customers = new Customers();
                        customers.customer = Integer.parseInt(customer);
                        customers.age = Integer.parseInt(age);
                        customers.income = Double.parseDouble(income);
                        customers.purchased_product = purchased_product;
                        coll.append(customers);
                 }

顧客に関するデータで customers オブジェクトを作成しました。このオブジェクトは、GridDB コンテナに追加されます。

GridDB からデータを取得する

いよいよGridDBコンテナからデータを取り出します。以下のコードを使用します。


                Query<customers> query = coll.query("select *");
                RowSet<customers> rs = query.fetch(false);
                RowSet res = query.fetch();

select *文は、データベースコンテナからすべてのデータを問い合わせるのに役立ちます。

分類器の構築

いよいよ KNN アルゴリズムと読み込まれたデータを使って分類器を構築する時が来ました。そのために必要なライブラリをインポートしましょう。

import java.io.IOException;
import java.util.Enumeration;
import java.text.DecimalFormat;

import weka.classifiers.Classifier;
import weka.core.Instances;
import weka.classifiers.lazy.IBk;
import weka.classifiers.Evaluation;
import weka.core.Instance;
import weka.core.converters.ArffLoader;

それでは、モデルを構築し、その統計情報をプリントアウトしてみましょう。

res.setClassIndex(res.numAttributes() - 1);
        Classifier cls = new IBk(1);        
        cls.buildClassifier(res);
    
        System.out.println(cls);
       
        Evaluation evaluation = new Evaluation(res);
        evaluation.evaluateModel(cls, res);
        
        System.out.println(evaluation.toSummaryString());
        System.out.println(evaluation.toClassDetailsString());
        System.out.println(evaluation.toMatrixString());

IBk インスタンスを作成する際に k の値を指定しました。IBk インスタンスは整数の引数を取ります。1 という値を渡すと、1 つの最近傍を見つけることができます。2 を渡すと、2 つの最近傍を計算します。引数を渡さず、デフォルトのコンストラクタで呼び出した場合は、1近傍を計算します。今回の場合、1という値を渡しているので、顧客の最近傍を1件予測することになります。

コードのコンパイルと実行

まず、gsadm ユーザでログインします。作成した .java ファイルを GridDB の bin フォルダに移動します。移動先は以下の通りです。

/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin

次に、Linux端末で以下のコマンドを実行し、gridstore.jarファイルのパスを設定します。

export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar

次に、以下のコマンドを実行して .java ファイルをコンパイルします。

javac KNNeighbor.java

以下のコマンドを実行して生成された .class ファイルを実行します。

java KNNeighbor

KNNモデルは、顧客の最近傍を1つ返します。

ブログの内容について疑問や質問がある場合は Q&A サイトである Stack Overflow に質問を投稿しましょう。 GridDB 開発者やエンジニアから速やかな回答が得られるようにするためにも "griddb" タグをつけることをお忘れなく。 https://stackoverflow.com/questions/ask?tags=griddb

Leave a Reply

Your email address will not be published. Required fields are marked *