JavaとGridDBを用いたナイーブベイズ分類器

はじめに

ナイーブベイズアルゴリズムは、ベイズの定理に基づく分類手法です。これは予測変数が互いに独立であることを仮定しています。ナイーブベイズ分類器は、あるクラスにおけるある特徴の存在が、他のどの特徴の存在とも関係しないことを仮定します。

例えば、リンゴの果実は、赤い色、丸い形、直径約3インチという特徴があります。これらの特徴は互いに依存しあっていますが、独立にその果実がリンゴである確率に寄与しています。だから “ナイーブ” と呼ばれるのです。

ナイーブベイズは、簡単に構築できるモデルであり、非常に大きなデータセットにも十分に適用できます。そのシンプルさにも関わらず、ナイーブベイズは最も洗練された分類アルゴリズムでさえも凌駕しています。

今回は、JavaとGridDBを使って、ナイーブベイズ分類器を実装する方法を説明します。目標は顧客が日にち・割引・無料配送といった情報に基づいて製品を購入するかどうかを予測することです。

GridDBにデータを格納する

データは “shopping.csv “という名前のCSVファイルに保存されています。このデータをGridDBに移行し、クエリの性能向上など、その利点を享受したいと思います。

これに使用するライブラリをインポートしてみましょう。

import java.io.File;
import java.io.IOException;
import java.util.Properties;
import java.util.Collection;
import java.util.Scanner;

import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;

次に、データを格納する GridDB コンテナを表す静的 Java クラスを作成します。

public static class ShoppingData {
     @RowKey String  day;
     String discount;
     String free_delivery;
     String purchase;
    } 

上記の Java クラスを 4 つのカラムを持つ SQL テーブルとして見てください。4つの変数は、GridDBコンテナのカラムを表しています。

それでは、JavaからGridDBコンテナに接続してみましょう。ここでは、GridDBをインストールした際の認証情報を使用します。

        Properties props = new Properties();
        props.setProperty("notificationAddress", "239.0.0.1");
        props.setProperty("notificationPort", "31999");
        props.setProperty("clusterName", "defaultCluster");
        props.setProperty("user", "admin");
        props.setProperty("password", "admin");
        GridStore store = GridStoreFactory.getInstance().getGridStore(props);

コンテナには “ShoppingData “という名前がついています。それを選択してみましょう。

Collection<String, ShoppingData> coll = store.putCollection("col01", ShoppingData.class);

ここでは、ShoppingData コンテナを参照するために coll という名前を使用します。

それでは、shopping.csvのデータをGridDBに書き込んでみましょう。

                File file1 = new File("shopping.csv");
                Scanner sc = new Scanner(file1);
                String data = sc.next();
 
                while (sc.hasNext()){
                        String scData = sc.next();
                        String dataList[] = scData.split(",");
                        String day = dataList[0];
                        String discount = dataList[1];
                        String free_delivery = dataList[2];
                        String purchase = dataList[3];
                        
                        ShoppingData sd = new ShoppingData();
                        sd.day = day;
                        sd.discount= discount;
                        sd.free_delivery = free_delivery;
                        sd.purchase = purchase;
                        
                        coll.append(sd);
                 }

上記のコードで、GridDBコンテナにデータが追加されます。

データを取得する

GridDBからデータを取得し、それを使ってナイーブベイズ分類器を実装することができます。以下のコードでデータを取得することができます。

                Query<shoppingdata> query = coll.query("select *");
                RowSet<shoppingdata> rs = query.fetch(false);
                RowSet res = query.fetch();

select * ステートメントにより、コンテナに格納されているすべてのデータを取得することができました。

ナイーブベイズ分類器の実装

さて、データが揃ったので、それを使ってナイーブベイズアルゴリズムを使った機械学習モデルを学習してみましょう。ここではWekaライブラリを使用します。まず、モデルの学習に使用するライブラリをすべてインポートしましょう。

import weka.core.Instances;
import weka.filters.Filter;
import java.io.FileReader;
import java.io.BufferedReader;
import weka.classifiers.Evaluation;
import weka.classifiers.Classifier;
import weka.core.converters.ArffLoader;
import weka.classifiers.bayes.NaiveBayesMultinomial;
import weka.filters.unsupervised.attribute.StringToWordVector;

データセット用のバッファードリーダーとインスタンスを作成しましょう。

            BufferedReader bufferedReader
                = new BufferedReader(
                    new FileReader(res));

            // Create dataset instances
            Instances datasetInstances
                = new Instances(bufferedReader);

それでは、ナイーブベイズ用の多項式Weka分類器を使って、モデルの構築と評価を行ってみましょう。

        datasetInstances.setClassIndex(datasetInstances.numAttributes()-1);

        Classifier classifier = new NaiveBayesMultinomial();
        classifier.buildClassifier(datasetInstances);
    
        Evaluation eval = new Evaluation(datasetInstances);
        eval.evaluateModel(classifier, datasetInstances);

        System.out.println("Naive Bayes Classifier Evaluation Summary");
        System.out.println(eval.toSummaryString());
        System.out.print(" the input data expression as per the alogorithm is ");
        System.out.println(classifier);

予測する

データセットの最後のインスタンスをモデルの訓練に使いませんでした。それを使って予測を行いたいと思います。以下のように、Weka ライブラリの classifyInstance() 関数を使用することにします。

        Instance pred = datasetInstances.lastInstance();
        double answer = classifier.classifyInstance(pred);
        System.out.println(answer);

モデルのコンパイルと実行

上記のナイーブベイズ分類器をコンパイルして実行するには、Weka APIが必要です。以下のURLからダウンロードしてください。

http://www.java2s.com/Code/Jar/w/weka.htm

次に、gsadm ユーザでログインします。作成した .java ファイルを GridDB の bin フォルダに移動します。移動先は以下の通りです。

/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin

Linux端末で以下のコマンドを実行し、gridstore.jarファイルのパスを設定します。

export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar

次に、以下のコマンドを使用して .java ファイルをコンパイルします。

javac -cp weka-3-7-0/weka.jar NaiveBayesClassifierExample.java

以下のコマンドを実行して生成された.classファイルを実行します。

java -cp .:weka-3-7-0/weka.jar NaiveBayesClassifierExample

予測結果は、お客様が購入することを示しています。

ブログの内容について疑問や質問がある場合は Q&A サイトである Stack Overflow に質問を投稿しましょう。 GridDB 開発者やエンジニアから速やかな回答が得られるようにするためにも "griddb" タグをつけることをお忘れなく。 https://stackoverflow.com/questions/ask?tags=griddb

Leave a Reply

Your email address will not be published. Required fields are marked *