ランダムフォレストは分類や回帰の問題を解決するために用いられる教師ありの機械学習アルゴリズムです。このアルゴリズムは、複雑な問題を解決し、モデルの性能を向上させるために、多くの分類器を組み合わせるプロセスであるアンサンブル学習の概念を使用しています。
ランダムフォレスト分類器は、特定のデータセットの様々なサブセットに対する多数の決定木を含み、その平均値を用いてデータセットの予測精度を向上させます。ランダムフォレストは1つの決定木に依存せず、各決定木からの予測を取り込み、予測の多数決に基づいて最終的な予測を行います。
フォレスト内のフォレストの数が多いほど、アルゴリズムの精度が向上し、オーバーフィッティングの可能性が低くなります。今回はJavaのランダムフォレストアルゴリズムとGridDBを使って、ユーザーの性別、年齢、給与からSUVを購入するかどうかを予測します。
GridDBにデータを格納する
使用するデータは、性別、年齢、給与、SUVを購入したかどうかです。このデータは、”suv.csv “というCSVファイルに格納されています。このデータをGridDBに格納し、クエリの性能向上などのメリットを享受したいと思います。
まず、これに使用するライブラリ一式をインポートしましょう。
import java.io.File;
import java.util.Scanner;
import java.io.IOException;
import java.util.Properties;
import java.util.Collection;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStoreFactory;
それでは、データを格納する GridDB コンテナを表す静的 Java クラスを作成しましょう。
public static class SuvData {
@RowKey int user_id;
String gender;
int age;
int estimatedSalary;
int purchased;
}
GridDBコンテナの5つのカラムは、上記のクラスではそれぞれ変数として表現されています。
GridDBコンテナへの接続を支援するJavaコードを書いてみましょう。ここでは、GridDBをインストールした際の認証情報を使用します。
Properties props = new Properties();
props.setProperty("notificationAddress", "239.0.0.1");
props.setProperty("notificationPort", "31999");
props.setProperty("clusterName", "defaultCluster");
props.setProperty("user", "admin");
props.setProperty("password", "admin");
GridStore store = GridStoreFactory.getInstance().getGridStore(props);
GridDB への接続が確立されたので、これから扱うコンテナを選択しましょう。コンテナの名前は SuvData
とします。
Collection<String, SuvData> coll = store.putCollection("col01", SuvData.class);
コンテナのインスタンスを作成し、coll
という名前を付けました。この名前を用いてコンテナを参照することになります。
今度は suv.csv
ファイルから GridDB コンテナにデータを書き込みます。
File file1 = new File("suv.csv");
Scanner sc = new Scanner(file1);
String data = sc.next();
while (sc.hasNext()){
String scData = sc.next();
String dataList[] = scData.split(",");
String user_id = dataList[0];
String gender = dataList[1];
String age = dataList[2];
String estimatedSalary = dataList[3];
String purchased = dataList[4];
SuvData sd = new SuvData();
sd.user_id = Integer.parseInt(user_id);
sd.gender = gender;
sd.age = Integer.parseInt(age);
sd.estimatedSalary = Integer.parseInt(estimatedSalary);
sd.purchased = Integer.parseInt(purchased);
coll.append(sd);
}
このコードでは、指定されたGridDBコンテナにデータを追加します。
データを取得する
GridDBに格納されたデータを使って、ランダムフォレストの機械学習モデルを学習させるためにデータを取得する必要があります。以下のコードで、GridDBコンテナからデータを取得することができます。
Query<suvdata> query = coll.query("select *");
RowSet<suvdata> rs = query.fetch(false);
RowSet res = query.fetch();
ここでは、select *
ステートメントを使用して、GridDBコンテナに格納されているすべてのデータを取得しました。
ランダムフォレスト分類器の実装
このデータを使って、ランダムフォレスト分類器を作ることができます。ここでは、Wekaが提供するランダムフォレスト分類器を使用します。まずはその為のライブラリのインポートから始めましょう。
import java.io.File;
import weka.core.Instances;
import weka.filters.Filter;
import java.io.IOException;
import java.io.FileReader;
import java.io.BufferedReader;
import weka.core.Instance;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.converters.ArffLoader;
import weka.classifiers.trees.RandomForest;
import weka.filters.unsupervised.attribute.StringToWordVector;
データをメモリバッファに格納し、インスタンスに変換してみましょう。
BufferedReader bufferedReader = new BufferedReader(new FileReader(res));
// Create dataset instances
Instances datasetInstances = new Instances(bufferedReader);
このデータセットを使って、10本の木からなるランダムフォレスト分類器を作ることができます。データセットの最後のインスタンスを予備として、それを使って予測をしてみましょう。
datasetInstances.setClassIndex(datasetInstances.numAttributes()-1);
RandomForest forest=new RandomForest();
forest.setNumTrees(10);
forest.buildClassifier(datasetInstances);
それでは、このモデルを評価し、そのサマリー統計を求めてみましょう。
Evaluation eval = new Evaluation(datasetInstances);
eval.evaluateModel(forest, datasetInstances);
System.out.println("Random Forest Classifier Evaluation Summary");
System.out.println(eval.toSummaryString());
System.out.print(" The expression for the input data as per algorithm is: ");
System.out.println(forest);
予測する
これで、データセットの最後のインスタンスを使って予測を行うことができます。Weka ライブラリの classifyInstance()
関数は、ユーザが SUV を買うかどうかを知るのに役立ちます。これは以下のようになります。
Instance pred = datasetInstances.lastInstance();
double answer = forest.classifyInstance(pred);
System.out.println(answer);
モデルのコンパイルと実行
上記のモデルをコンパイルして実行するために、Weka APIを使用することにします。APIは以下のURLからダウンロードしてください。
http://www.java2s.com/Code/Jar/w/weka.htm
次に、gsadm
ユーザでログインします。作成した .java
ファイルを GridDB の bin
フォルダに移動します。移動先は以下の通りです。
/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin
Linux端末で以下のコマンドを実行し、gridstore.jarファイルのパスを設定します。
export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar
以下のコマンドを実行して、.java
ファイルをコンパイルしてください。
javac -cp weka-3-7-0/weka.jar RandomForestAlgorithm.java
以下のコマンドを実行して生成された .class ファイルを実行します。
java -cp .:weka-3-7-0/weka.jar RandomForestAlgorithm
この予測は、ユーザーがSUVを購入することを示しています。
ブログの内容について疑問や質問がある場合は Q&A サイトである Stack Overflow に質問を投稿しましょう。 GridDB 開発者やエンジニアから速やかな回答が得られるようにするためにも "griddb" タグをつけることをお忘れなく。 https://stackoverflow.com/questions/ask?tags=griddb