決定木を使ってジムに行くかを予測する

はじめに

決定木は分類問題と回帰問題の両方を解決するために使用できる教師あり機械学習アルゴリズムです。このアルゴリズムの目的は、学習データから推測される決定規則を学習することで、対象変数の値やクラスを予測できるモデルを作成することです。

決定木を使ってレコードのクラスラベルを予測する場合、ツリーのルートから始めます。そして、ルートの属性値とレコードの属性値を比較します。比較の結果、その値に対応する枝をたどり、次のノードに進みます。

決定木アルゴリズムは、ルートノードからリーフ/ターミナルノードへと適切に並べ替え、対象を分類します。

この記事では、決定木アルゴリズムをJavaで実装する方法について説明します。年齢と体重から、ある個人がジムに通うかどうかを予測します。

GridDBにデータを格納する

これまで、データは gym.csv という名前の .csv ファイルに保存されていましたが、GridDB に移動する必要があります。CSVファイルのデータも利用できますが、GridDBの方がクエリのパフォーマンスが向上するなど、様々なメリットがあります。

データセットには2つの独立変数、すなわち年齢体重と1つの従属変数、ジムがあります。従属変数の値 1 は、その人がジムに行くことを示し、値 0 は、その人がジムに行かないことを示します。

まず、使用するライブラリ一式をインポートしましょう。

import java.io.File;
import java.io.IOException;
import java.util.Properties;
import java.util.Collection;
import java.util.Scanner;

import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;

ここで、使用する GridDB コンテナを表す静的 Java クラスを作成します。

public static class GymData {
     @RowKey int  age;
     int weight; 
     int gym;
    } 

上記のJavaクラスは、3つのカラムを持つSQLテーブルと同等です。3つの変数は、GridDBコンテナのカラムを表します。

それでは、JavaとGridDBコンテナとの接続を確立してみましょう。ここでは、GridDBをインストールした際の認証情報を使用します。

        Properties props = new Properties();
        props.setProperty("notificationAddress", "239.0.0.1");
        props.setProperty("notificationPort", "31999");
        props.setProperty("clusterName", "defaultCluster");
        props.setProperty("user", "admin");
        props.setProperty("password", "admin");
        GridStore store = GridStoreFactory.getInstance().getGridStore(props);

ここでは、GymDataという名前のGridDBコンテナを使用します。それを選択してみましょう。

Collection<String, GymData> coll = store.putCollection("col01", GymData.class);

ここでは、GymData コンテナを参照するために、coll という名前を使用します

それでは、gym.csvファイルからGridDBにデータを移動してみましょう。

                File file1 = new File("gym.csv");
                Scanner sc = new Scanner(file1);
                String data = sc.next();
 
                while (sc.hasNext()){
                        String scData = sc.next();
                        String dataList[] = scData.split(",");
                        String age = dataList[0];
                        String weight = dataList[1];
                        String gym = dataList[2];
                        
                        GymData gd = new GymData();
                        gd.age = Integer.parseInt(age);
                        gd.weight = Integer.parseInt(weight);
                        gd.gym = Integer.parseInt(gym);
                        
                        coll.append(gd);
                }

上記のコードで、GridDBコンテナにデータが追加されます。

データを取得する

ここでは、GridDBコンテナからデータを取得し、それを使って決定木アルゴリズムを使ったモデルを実装してみたいと思います。以下のコードで、データを取得することができます。

                Query<GymData> query = coll.query("select *");
                RowSet<GymData> rs = query.fetch(false);
                RowSet res = query.fetch();

select *select allを意味し、コンテナに格納されているすべてのデータを取り出すのに役立ちます。

決定木モデルの実装

目標は、そのデータを使って、個人がジムに行くかどうかを予測できる機械学習モデルを訓練することです。モデルの学習には、決定木アルゴリズムを使用します。Wekaライブラリから必要なライブラリをインポートしましょう。

import java.io.IOException;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;
import weka.core.Instance;
import weka.core.converters.ArffLoader; 

import java.io.FileReader;
import java.io.BufferedReader;

import java.io.IOException;
import weka.classifiers.trees.Id3;
import weka.classifiers.trees.J48;
import weka.core.converters.ArffLoader;

これで、データセット用のバッファードリーダーとインスタンスを作成することができます。

            BufferedReader bufferedReader
                = new BufferedReader(
                    new FileReader(res));
 
            // Create dataset instances
            Instances datasetInstances
                = new Instances(bufferedReader);

それでは、Weka ライブラリの buildClassifier() 関数を呼び出して、決定木分類器を構築してみましょう。

            datasetInstances.setClassIndex(datasetInstances.numAttributes()-1);
            
            Classifier myclassifier = new J48();
            myclassifier.buildClassifier(datasetInstances);
            System.out.println(myclassifier);

データセットの最後のインスタンスは,分類器の構築に使用されていません。

予測する

上記のモデルと、データセットの最後のインスタンスを使って、予測をしてみましょう。以下のように、Wekaライブラリの classifyInstance() 関数を使用します。

        Instance pred = datasetInstances.lastInstance();
        double answer = myclassifier.classifyInstance(pred);
        System.out.println(answer);

モデルのコンパイルと実行

モデルをコンパイルして実行するには、Weka API が必要です。以下のURLからダウンロードしてください。

http://www.java2s.com/Code/Jar/w/weka.htm

次に、gsadm ユーザでログインします。作成した .java ファイルを GridDB の bin フォルダに移動します。移動先は以下の通りです。

/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin

Linux端末で以下のコマンドを実行し、gridstore.jarファイルのパスを設定します。

export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar

次に、以下のコマンドを使用して .java ファイルをコンパイルします。

javac -cp weka-3-7-0/weka.jar DecisionTreeAlgorithm.java

以下のコマンドを実行して生成された .class ファイルを実行します。

java -cp .:weka-3-7-0/weka.jar DecisionTreeAlgorithm

モデルは予測値として「0」を返しました。これは、その人がジムに行かないことを意味します。

ブログの内容について疑問や質問がある場合は Q&A サイトである Stack Overflow に質問を投稿しましょう。 GridDB 開発者やエンジニアから速やかな回答が得られるようにするためにも "griddb" タグをつけることをお忘れなく。 https://stackoverflow.com/questions/ask?tags=griddb

Leave a Reply

Your email address will not be published. Required fields are marked *