決定木アルゴリズムのJavaでの実装方法

統計的手法、機械学習、人工知能の組み合わせにより、多くの分野で支援システムとなる様々なデータマイニング技術が開発されています。これらの手法は、様々な分野や産業で不可欠なものとなりつつあります。

この記事では、Javaによる決定木アルゴリズムの実装について説明します。まず、この記事で使用されているデータセットについて説明します。その後、Weka J48 決定木を使用したアルゴリズムの実装を紹介します。その後、メインデータベースであるGridDBとの異なるインタラクションを説明します。この場合、我々はオリジナルのデータセットをGridDBに書き込み、データを保存し、後で決定木を構築するためにそれを取り出すという使い方をします。

必要条件

ここでは、本記事で使用する要件と構成について説明します。

Weka 3.9: weka.jar ファイルをダウンロードし、/usr/share/java/パスに配置します。

GridDB 4.6: インストール後、GridDB クラスタがアクティブである必要があります。

CLASSPATHにWekaライブラリのパスを追加することを確認します。GridDBについても同様の操作を行います。以下は対応するコマンドラインです。

$ export CLASSPATH=${CLASSPATH}:/usr/share/java/weka.jar
$ export CLASSPATH=$CLASSPATH:/usr/share/java/gridstore.jar

データセット

この記事では、アヤメのデータセットを使用することにしました。このデータセットはオープンソースでこちらに公開されており、5つの属性を持つ149のエントリで構成されています。

  1. 萼片の長さ(cm): アヤメ科植物の萼片の長さを測定したもの
  2. 萼片の幅(cm): アヤメ科植物の萼片の幅を測定したもの
  3. 花弁の長さ(cm): アヤメ科植物の花弁の長さを測定したもの
  4. 花弁の幅(cm): アヤメ科植物の花弁の幅を測定したもの
  5. クラス: アヤメの亜種を示す。設定可能な値は以下の通り。Iris Setos, Iris Versicolour, Iris Virginica.

以下はデータセットの抜粋です。

4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor

データセットは.CSVファイルの形で提供され、それを読み込み、パースし、将来GridDBコレクションに保存するために一時的に変数に格納されます。以下のコードでこれらの処理を行います。

// Handlig Dataset and storage to GridDB
File data = new File("/home/ubuntu/griddb/gsSample/iris.csv");
Scanner sc = new Scanner(data);  
sc.useDelimiter("\n");
while (sc.hasNext())  //returns a boolean value  
{  
    Row row = collection.createRow();

    String line = sc.next();
    String columns[] = line.split(",");
    float slenght = Float.parseFloat(columns[0]);
    float swidth = Float.parseFloat(columns[1]);
    float plenght = Float.parseFloat(columns[2]);
    float pwidth = Float.parseFloat(columns[3]);
    String irisclass = columns[4];
}

終了したら、スキャナを閉じるのを忘れないでください。

sc.close();

花弁と萼の寸法からアヤメの亜種を予測するために、決定木を実装しました。以下では、Javaによる決定木の実装を説明します。

Javaによる決定木アルゴリズムの実装

前項で述べたように、この記事ではWekaパッケージで利用可能なJ48決定木を使用します。このクラスは、刈り込み済みまたは刈り込みなしのC4.5決定木を生成します。では、その実装を詳しく見ていきましょう。まず、Wekaに必要なパッケージをインポートすることから始めます。

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.classifiers.trees.J48;
import weka.classifiers.Evaluation;

J48決定木アルゴリズムには、データセットとアルゴリズムをマッチングさせるために微調整できる一連の属性があります。今回の場合、以下の2つを設定しました。

-C . 枝刈りの信頼度。枝刈りのための信頼性閾値を設定。

-M . 最小インスタンス数。葉ごとの最小インスタンス数を設定。

ここでは、枝刈りの信頼度を0.25とし、最小インスタンス数を30とします。これは以下のコードで実現されます。

String[] options = new String[4];
options[0] = "-C";
options[1] = "0.25";
options[2] = "-M";
options[3] = "30";

決定木が完成しました。では、データを用意し、それを使ってアルゴリズムを構築していきます。

パッケージのインポート

java.utilパッケージのさまざまなクラスを使用します。ArrayListとListはデータを整理するために使います。Propertiesクラスは、GridDBストアのインスタンスにクラスタ接続パラメータを渡すために使用されます。最後に、Randomクラスは、決定木の構築における交差検証フェーズで、シードパラメータをランダム化するために使用されます。

import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.Random;

GridDBと接続し、対話するために、さまざまなパッケージがインポートされます。

import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.ColumnInfo;
import com.toshiba.mwcloud.gs.Container;
import com.toshiba.mwcloud.gs.ContainerInfo;
import com.toshiba.mwcloud.gs.GSType;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.Row;
import com.toshiba.mwcloud.gs.RowSet;

最後に、元のデータセットを含むファイルを操作するためのパッケージをインポートします。

import java.io.IOException;
import java.util.Scanner;
import java.io.File;
import java.io.BufferedReader;
import java.io.FileReader;

GridDBにデータを書き込む

GridDBにデータを書き込むために、まず、接続プロパティを目的のクラスタに設定します。同様に、GridStoreインスタンスを作成し、コンテナ名containerNameで参照されるコンテナを保持します。

Properties prop = new Properties();
prop.setProperty("notificationAddress", "239.0.0.1");
prop.setProperty("notificationPort", "31999");
prop.setProperty("clusterName", "cluster");
prop.setProperty("database", "public");
prop.setProperty("user", "admin");
prop.setProperty("password", "admin");

GridStore store = GridStoreFactory.getInstance().getGridStore(prop);
store.getContainer("newContainer");
String containerName = "last";

これでストアとコンテナの作成に成功したので、コンテナ情報を設定し、カラムのリストとそれに対応するデータ型を設定して、コンテナスキーマの定義を開始することができます。

// Define ontainer schema and columns
ContainerInfo containerInfo = new ContainerInfo();
List<columninfo> columnList = new ArrayList</columninfo><columninfo>();
columnList.add(new ColumnInfo("key", GSType.INTEGER));
columnList.add(new ColumnInfo("slenght", GSType.FLOAT));
columnList.add(new ColumnInfo("swidth", GSType.FLOAT));
columnList.add(new ColumnInfo("plenght", GSType.FLOAT));
columnList.add(new ColumnInfo("pwidth", GSType.FLOAT));
columnList.add(new ColumnInfo("irisclass", GSType.STRING));

containerInfo.setColumnInfoList(columnList);
containerInfo.setRowKeyAssigned(true);

GridDBには、2種類のコンテナが用意されています。今回は、一般的なデータを管理するためのCollection型を選択します。この目的のために、コレクションを作成し、データを整理するための行のリストを作成します。また、元の CSV ファイルからデータを取得するために、各反復処理でデータ行を保持する Row クラスをインスタンス化する予定です。

Collection<Void, Row> collection = store.putCollection(containerName, containerInfo, false);
List<row> rowList = new ArrayList<row>();
Row row = collection.createRow();

GridDBにデータを格納する

このとき、データセットから得た値を一時的に格納するための変数を呼び出し、各行に挿入します。その後、前節で定義した行リストに各行を追加します。

row.setInteger(0,key);
row.setFloat(1,slenght );
row.setFloat(2, swidth);
row.setFloat(3, plenght);
row.setFloat(4, pwidth);
row.setString(5, irisclass);
rowList.add(row);

GridDBにデータを格納するために、以下のような行を使用します。

collection.put(rowList);

GridDBからデータを取得する

GridDBに格納したデータを取得するために、TQLクエリを実行します。この例では、コンテナ内のすべてのデータを選択します。

Query<row> query = container.query("SELECT * ");
RowSet<row> rs = query.fetch();

データを可視化したい場合は以下のようにします。

while ( rs.hasNext() ) {
    Row row = rs.next();
    float slenght = row.getFloat(0);
    float swidth = row.getFloat(1);
    float plenght = row.getFloat(2);
    float pwidth = row.getFloat(3);
    String irisclass = row.getString(4);
    System.out.println(" slenght=" + slenght + ", swidth=" + swidth + ", plenght=" + plenght +", pwidth=" + pwidth+", irisclass=" + irisclass);
}

決定木の構築

GridDBにデータが格納されたので、このデータを使って決定木を構築する準備ができました。分類木を構築する Weka 関数は、データの Instance をパラメータとして受け取るので、Weka 関数に適したデータにするために、以下のコードを記述する必要があります。

BufferedReader bufferedReader= new BufferedReader(new FileReader(res));
Instances datasetInstances= new Instances(bufferedReader);

このレベルでは、単純に分類器の構築に進みます。

mytree.buildClassifier(datasetInstances);

最後に、構築した分類器ツリーの評価を作成します。

Evaluation eval = new Evaluation(datasetInstances);  eval.crossValidateModel(mytree, datasetInstances, 10, new Random(1));

コードの最後の行で、データインスタンスに対してクロスバリデーションを実行していることがわかります。このプロセスにより、データセットが異なる方法で分割され、偏りのない結果が得られることが保証されます(特に、限られたデータセットで計算する場合)。

最後に、評価結果のサマリーを表示します。

System.out.println(eval.toSummaryString("\n ****** J48 *****\n", true));

コードのコンパイルと実行

私たちのコンテキストでは、.java ファイルが gsSample デフォルトフォルダのレベルにあり、この記事で説明されているコードが含まれています。GridDB フォルダに移動し、以下のコマンドを実行して、コードをコンパイル・実行します。

~/griddb$ javac gsSample/Select.java
~/griddb$ java gsSample/Select.java

このコードを実行する際、コマンドラインに追加のクラスターパラメーターは必要ないことがわかります。

結論

以下は、前節で印刷した評価概要の抜粋です。

---Registering Weka Editors---
Trying to add database driver (JDBC): jdbc.idbDriver - Error, not in CLASSPATH?

 ****** J48 *****

Correctly Classified Instances         141               94.6309 %
Incorrectly Classified Instances         8                5.3691 %
Kappa statistic                          0.9195
K&B Relative Info Score              13340.283  %
K&B Information Score                  211.4385 bits      1.4191 bits/instance
Class complexity | order 0             236.1698 bits      1.585  bits/instance
Class complexity | scheme             2179.3992 bits     14.6268 bits/instance
Complexity improvement     (Sf)      -1943.2293 bits    -13.0418 bits/instance
Mean absolute error                      0.0578
Root mean squared error                  0.1831
Relative absolute error                 13.0031 %
Root relative squared error             38.8358 %
Total Number of Instances              149

このように、決定木はアヤメをクラス分けする際に、94.6%の精度を達成することができたのです。結果とその解釈についての詳細は、Weka J48の公式ドキュメントをご覧ください。

TQL クエリ、コンテナ、GridDB ストアを閉じるのを忘れないでください。

query.close();
container.close();
store.close();

ブログの内容について疑問や質問がある場合は Q&A サイトである Stack Overflow に質問を投稿しましょう。 GridDB 開発者やエンジニアから速やかな回答が得られるようにするためにも "griddb" タグをつけることをお忘れなく。 https://stackoverflow.com/questions/ask?tags=griddb

Leave a Reply

Your email address will not be published. Required fields are marked *