Javaによるランダムフォレストの実装方法

ランダムフォレストは、分類などに用いられる機械学習アルゴリズムの一つです。本ブログでは、アヤメ科植物のクラスを予測するランダムフォレストアルゴリズムのJavaによる実装を説明します。そのためにまずは要件定義を行い、必要なパッケージをインポートします。次に、アイリスデータセットを提示し、Weka ライブラリを用いてランダムフォレストアルゴリズムを実装します。ファイルからデータを取得し、GridDBに格納します。そしてデータを取得し、ランダムフォレストアルゴリズムを実行します。最後に、その結果について考察します。

必要条件

以下のセクションで紹介するコードは、データセットの保存と取得にGridDBを使用します。そのため、GridDBをダウンロードし、ノードを作成し、クラスタに参加する必要があります。Ubuntu 18.04でGridDBを利用する場合は、環境変数の更新を忘れないようにしてください。

export GS_HOME=$PWD
export GS_LOG=$PWD/log
export PATH=${PATH}:$GS_HOME/bin

同じことをgridstore.jarパッケージに対して行います。

export CLASSPATH=$CLASSPATH:/usr/share/java/gridstore.jar 

ランダムフォレストアルゴリズムについては、Wekaライブラリをダウンロードし、その場所を示すようにJavaの環境変数を更新します。

export CLASSPATH=${CLASSPATH}:/usr/share/java/weka.jar

Javaコードのレベルでは、GridDBクラスタに接続し、GridDBストアとコンテナを作成します。また、コンテナスキーマ、コレクション、カラムを定義する必要があります。これを実現するためのコードを以下に示します。

            // Manage connection to GridDB
            Properties properties = new Properties();
            properties.setProperty("notificationAddress", "239.0.0.1");
            properties.setProperty("notificationPort", "31999");
            properties.setProperty("clusterName", "cluster");
            properties.setProperty("database", "public");
            properties.setProperty("user", "admin");
            properties.setProperty("password", "admin");

            //Get Store and Container
            GridStore store = GridStoreFactory.getInstance().getGridStore(properties);
            store.getContainer("newContainer");
 
            String containerName = "mContainer";
       
            // Define container schema and columns
            ContainerInfo containerInfo = new ContainerInfo();
            List<columninfo> columnList = new ArrayList();
            columnList.add(new ColumnInfo("key", GSType.INTEGER));
            columnList.add(new ColumnInfo("slenght", GSType.FLOAT));
            columnList.add(new ColumnInfo("swidth", GSType.FLOAT));
            columnList.add(new ColumnInfo("plenght", GSType.FLOAT));
            columnList.add(new ColumnInfo("pwidth", GSType.FLOAT));
            columnList.add(new ColumnInfo("irisclass", GSType.STRING));
 
            containerInfo.setColumnInfoList(columnList);
            containerInfo.setRowKeyAssigned(true);
            Collection<Void, Row> collection = store.putCollection(containerName, containerInfo, false);
            List<row> rowList = new ArrayList();

このアプリケーションでは、主に4つのパッケージのクラスを使用しています。

  • java.util: ArrayList、List、Random、Scanner などのユーティリティクラスが含まれる。
  • java.io: ファイルからデータセットを読み込むための入出力操作を行うことができます。
  • com.toshiba.mwcloud.gs: GridDB とのデータインタラクションに使用します。
  • weka.classifier.trees: ランダムフォレストアルゴリズムを実装するためのクラスが含まれています。

以下は、パッケージをインポートするコードです。これらのクラスは、以下のセクションで様々な場面で使用されます。

// ---------- Java Util ---------
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.Random;
import java.util.Scanner;
 
// ---------- Java IO ---------
import java.io.IOException;
import java.io.File;
import java.io.BufferedReader;
import java.io.FileReader;
 
// ---------- GridDB ---------
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.ColumnInfo;
import com.toshiba.mwcloud.gs.Container;
import com.toshiba.mwcloud.gs.ContainerInfo;
import com.toshiba.mwcloud.gs.GSType;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.Row;
import com.toshiba.mwcloud.gs.RowSet;
 
//----------- Weka ---------
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.classifiers.trees.RandomForest;
import weka.classifiers.Evaluation;

データセット

本記事で使用するデータセットは、Wekaのダウンロード時に提供されるデフォルトデータセットから取得したIrisデータセットです。このデータセットには、アイリスという植物のサンプルから収集した150インスタンスのデータが含まれています。このデータセットには、萼の長さ、萼の幅、花弁の長さ、花弁の幅、植物のクラス(Iris Setosa, Iris Versicolour, Iris Virginica)の5つの属性が含まれています。各属性はデータセットの各列にそれぞれ対応します。以下はその抜粋です。

5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica

ランダムフォレストアルゴリズムを構築する前に、まずCSVファイルからデータセットを取得し、GridDBに書き込むことから始めます。

以下のコードでは、まずデータセットが格納されたCSVファイルをオープンします。そして、Scanner を使ってファイルを繰り返し読み、内容を抽出します。ファイルの区切り記号は、改行は n 、カラムは , です。

whileループの各反復で、データを文字列の配列に格納し、それを対応するデータ型にキャストして専用の変数に格納します。

            // Handling Dataset and storage to GridDB
            File data = new File("/home/ubuntu/griddb/gsSample/iris.csv");
            Scanner sc = new Scanner(data);
            sc.useDelimiter("\n");
            while (sc.hasNext())  //returns a boolean value
            {
                Row row = collection.createRow();

                String line = sc.next();
                String columns[] = line.split(",");
                float slenght = Float.parseFloat(columns[0]);
                float swidth = Float.parseFloat(columns[1]);
                float plenght = Float.parseFloat(columns[2]);
                float pwidth = Float.parseFloat(columns[3]);
                String irisclass = columns[4];
            }

後ほどデータセットから得られたデータをGridDBに書き込みます。

ランダムフォレストアルゴリズムを適用し、アヤメの花の萼片と花弁の寸法からアヤメの亜種を予測します。その実装については、次のセッションで説明します。

ランダムフォレストアルゴリズムのJavaによる実装

アイリスデータセットにランダムフォレストアルゴリズムを用いることで、決定木に比べてアイリス亜種の予測精度を向上させることを目的としています。実際、ランダムフォレストアルゴリズムは、より複雑なアルゴリズムを使用して予測を生成するため、決定木の改良版として紹介されています。ランダムフォレストは、その名前が示すように、より代表的であるように、結果を構築するために複数の決定木を使用します。この2つのアルゴリズムの違いは、決定木が貪欲法を用いて各ノードで決定を下すのに対し、ランダムフォレストは入力データから無作為の部分集合を選んで決定を下す点です。この違いが、実際にアイリスの亜種を予測する精度を向上させることになることを観察します。ランダムフォレストアルゴリズムを実装するために、我々はWekaライブラリを使用します。ランダムフォレストアルゴリズムは weka.classifiers.trees パッケージにあり、他の記事で紹介したJ48決定木のような他の分類アルゴリズムも含まれています。

GridDBへのデータ書き込み

これまでの節をよく見てみると、データセットから取得したデータをGridDBに書き込むためのGridDBデータ型として、RowList<Row>を定義しています。これを実現するのが、以下のコードです。

row.setInteger(0,i);
row.setFloat(1,slenght );
row.setFloat(2, swidth);
row.setFloat(3, plenght);
row.setFloat(4, pwidth);
row.setString(5, irisclass);
 
rowList.add(row);

GridDBにデータを格納する

GridDBへのデータの格納は、データセットのカラムとGridDBで定義されたコンテナスキーマの対応付けにより実現されます。以下のコードを思い出すと、カラムが直接マッピングされていることがわかります。まず、整数型のキーカラム、次に各属性を表す4つの浮動小数点変数、そして最後に文字列である虹彩のクラスがあります。前のセクションで使用したコードでは、各カラムにデータを挿入することができます。

 columnList.add(new ColumnInfo("key", GSType.INTEGER));
 columnList.add(new ColumnInfo("slenght", GSType.FLOAT));
 columnList.add(new ColumnInfo("swidth", GSType.FLOAT));
 columnList.add(new ColumnInfo("plenght", GSType.FLOAT));
 columnList.add(new ColumnInfo("pwidth", GSType.FLOAT));
 columnList.add(new ColumnInfo("irisclass", GSType.STRING));

以下のコードを実行すると、GridDBにデータが格納されます。

rowList.add(row);

GridDBからデータを取得する

データがGridDBに正しく格納されたことを確認するために、データベース内のすべてのレコードを取得するクエリを実行します。この操作を行うコードは以下の通りです。

        // Retrieving data from GridDB
        Container container = store.getContainer(containerName);
        if ( container == null ){
            throw new Exception("Container not found.");
        }
        Query<row> query = container.query("SELECT * ");
        RowSet<row> rowset = query.fetch();

得られたデータは rowset 変数に格納され、以下のように簡単にデータを取得したり、印刷したりすることができるようになります。

        // Print GridDB data
        while ( rowset.hasNext() ) {
            Row row = rowset.next();
            float slenght = row.getFloat(0);
            float swidth = row.getFloat(1);
            float plenght = row.getFloat(2);
            float pwidth = row.getFloat(3);
            String irisclass = row.getString(4);
            System.out.println(" slenght=" + slenght + ", swidth=" + swidth + ", plenght=" + plenght +", pwidth=" + pwidth+", irisclass=" + irisclass);
        }

ランダムフォレストを構築する

本記事で得られる結果は、ランダムフォレスト分類器を以下のパラメータで実行したものです。

P 100 -I 100 -num-slots 1 -K 0 -M 1.0 -V 0.001 -S 1

このアルゴリズムを Java で利用するには、まず RandomForest クラスのオブジェクトを作成することから始めます。

 RandomForest randomForest = new RandomForest();

次に、ランダムフォレストアルゴリズムのパラメーターの配列を指定します。これは次のようなコードで実現します。

String[] parameters = new String[14];
     
parameters[0] = "-P";
parameters[1] = "100";
parameters[2] = "-I";
parameters[3] = "100";
parameters[4] = "-num-slots";
parameters[5] = "1";
parameters[6] = "-K";
parameters[7] = "0";
parameters[8] = "-M";
parameters[9] = "1.0";
parameters[10] = "-V";
parameters[11] = "0.001";
parameters[12] = "-S";
parameters[13] = "1";

randomForest.setOptions(parameters);

Wekaのランダムフォレストパラメータの詳細については、公式のJava Random Forest クラスをご確認ください。

この後,学習データセットを用いて分類器を構築します.この時点で,Evaluation クラスを利用してアルゴリズムを評価する準備が整いました。そして、クロスバリデーションによりモデルを評価し、コマンドラインに出力する予測結果を得ます。これは次のようなコードで実現されます。

randomForest.setOptions(parameters);
randomForest.buildClassifier(datasetInstances);
 
Evaluation evaluation = new Evaluation(datasetInstances);
evaluation.crossValidateModel(randomForest, datasetInstances, numFolds, new Random(1));

System.out.println(evaluation.toSummaryString("\nResults\n======\n", true));

コードの準備ができたので、Javaコードのコンパイルと実行に進みます。

コードのコンパイルと実行

コマンドラインから GridDB フォルダに移動し、コマンドを実行します。Java コードは、gsSample フォルダ下の randomForest.java ファイルに格納されています。まず、このファイルをコンパイルし、次に実行します。これは以下のコマンドで実現できます。

javac gsSample/randomForest.java
java gsSample/randomForest.java

結論と結果

コードをコンパイルして実行すると、以下のような結果がコマンドラインに出力されます。また、Summaryセクションのみを印刷するオプションもあります。他のセクション、例えばアルゴリズムの実行の文脈を見ることができる「Run information」などを見ることができます。混同行列は、クラスに応じて分類されたインスタンスの詳細なカウントを提供します。

ランダムフォレストアルゴリズムを使用してIrisクラスを予測すると、精度は95.3%に達し、150のインスタンスのうち143が正しく分類されました。前回の記事では、決定木を使用した場合、精度は約94%でした。

=== Run information ===

Scheme:       weka.classifiers.trees.RandomForest -P 100 -I 100 -num-slots 1 -K 0 -M 1.0 -V 0.001 -S 1
Relation:     iris
Instances:    150
Attributes:   5
              sepallength
              sepalwidth
              petallength
              petalwidth
              class
Test mode:    10-fold cross-validation

=== Classifier model (full training set) ===

RandomForest

Bagging with 100 iterations and base learner

weka.classifiers.trees.RandomTree -K 0 -M 1.0 -V 0.001 -S 1 -do-not-check-capabilities

Time taken to build model: 0.01 seconds

=== Stratified cross-validation ===
=== Summary ===

Correctly Classified Instances         143               95.3333 %
Incorrectly Classified Instances         7                4.6667 %
Kappa statistic                          0.93  
Mean absolute error                      0.0408
Root mean squared error                  0.1621
Relative absolute error                  9.19   %
Root relative squared error             34.3846 %
Total Number of Instances              150     

=== Detailed Accuracy By Class ===

                 TP Rate  FP Rate  Precision  Recall   F-Measure  MCC      ROC Area  PRC Area  Class
                 1,000    0,000    1,000      1,000    1,000      1,000    1,000     1,000     Iris-setosa
                 0,940    0,040    0,922      0,940    0,931      0,896    0,991     0,984     Iris-versicolor
                 0,920    0,030    0,939      0,920    0,929      0,895    0,991     0,982     Iris-virginica
Weighted Avg.    0,953    0,023    0,953      0,953    0,953      0,930    0,994     0,989     

=== Confusion Matrix ===

  a  b  c   <-- classified as
 50  0  0 |  a = Iris-setosa
  0 47  3 |  b = Iris-versicolor
  0  4 46 |  c = Iris-virginica

ブログの内容について疑問や質問がある場合は Q&A サイトである Stack Overflow に質問を投稿しましょう。 GridDB 開発者やエンジニアから速やかな回答が得られるようにするためにも "griddb" タグをつけることをお忘れなく。 https://stackoverflow.com/questions/ask?tags=griddb

Leave a Reply

Your email address will not be published. Required fields are marked *