ランダムフォレストは分類予測に使用できる強力なモジュールで、決定木よりも良い性能を発揮します。この記事では、ランダムフォレストモジュールとクリーンなデータセットを組み合わせて、ある大学院への大学生の入学可能性を予測することに焦点を当てます。この問題は、業界でよく知られた予測セットです。機械学習アルゴリズムを実装するために、Javaプログラミングを使用する予定です。この記事では、データマイニングのプロセスについて、業界標準のプロセスに従います。最初に、データを取得し、それを準備し、予測のためのモジュールを作成し、それを実行し、最後に結果を評価することによってそれをテストします。
ソースコードとデータはこちら: https://github.com/griddbnet/Blogs/tree/admissions
必要条件
ランダムフォレストモジュールを作成するにあたり、二次記憶装置としてGridDBデータベースを使用します。GridDBをダウンロードおよび設定し、GridDBがあなたのシステムで稼働していることを確認することが重要です。
必ず以下のコマンドを実行して、必要な環境変数を更新してください。
export GS_HOME=$PWD
export GS_LOG=$PWD/log
export PATH=${PATH}:$GS_HOME/bin
export CLASSPATH=$CLASSPATH:/usr/share/java/gridstore.jar
export CLASSPATH=${CLASSPATH}:/usr/share/java/weka.jar
GridDBクラスタが動作していることを確認するために、Javaクラスで実装する必要があります。このタスクは GridDB コンテナを作成することから始める必要があります。次に、データセットのカラムの形で表示されるコンテナ・スキーマを定義します。
上記の作業を実現するために、以下のようなコードを使用します。
// Manage connection to GridDB
Properties properties = new Properties();
properties.setProperty("notificationAddress", "239.0.0.1");
properties.setProperty("notificationPort", "31999");
properties.setProperty("clusterName", "cluster");
properties.setProperty("database", "public");
properties.setProperty("user", "admin");
properties.setProperty("password", "admin");
// Get Store and Container
GridStore store = GridStoreFactory.getInstance().getGridStore(properties);
store.getContainer("newContainer");
String containerName = "mContainer";
// Define container schema and columns
ContainerInfo containerInfo = new ContainerInfo();
List<columninfo> columnList = new ArrayList<columninfo>();
columnList.add(new ColumnInfo("key", GSType.INTEGER));
columnList.add(new ColumnInfo("Serial No.", GSType.INTEGER));
columnList.add(new ColumnInfo("GRE Score", GSType.INTEGER));
columnList.add(new ColumnInfo("TOEFL Score", GSType.INTEGER));
columnList.add(new ColumnInfo("University Rating", GSType.INTEGER));
columnList.add(new ColumnInfo("SOP", GSType.FLOAT));
columnList.add(new ColumnInfo("LOR", GSType.FLOAT));
columnList.add(new ColumnInfo("CGPA", GSType.FLOAT));
columnList.add(new ColumnInfo("Research", GSType.INTEGER));
columnList.add(new ColumnInfo("Chance of Admit", GSType.FLOAT));
containerInfo.setColumnInfoList(columnList);
containerInfo.setRowKeyAssigned(true);
Collection<Void, Row> collection = store.putCollection(containerName, containerInfo, false);
List<row> rowList = new ArrayList<row>();
ランダムフォレストのJavaプログラムでは、4つのライブラリ群からクラスをインポートする必要があります。
java.util
: データセットの読み込みと書き込みに使用します。java.io
: 入出力タスクに使用します。com.toshiba.mwcloud.gs
: GridDB データベースの設定と操作に使用します。weka.classifier.trees
: ランダムフォレストモジュールを実装するために使用します。
上記のタスクを実装するために使用されるJavaコードは以下の通りです。
// ---------- Java Util ---------
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.Random;
import java.util.Scanner;
// ---------- Java IO ---------
import java.io.IOException;
import java.io.File;
import java.io.BufferedReader;
import java.io.FileReader;
// ---------- GridDB ---------
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.ColumnInfo;
import com.toshiba.mwcloud.gs.Container;
import com.toshiba.mwcloud.gs.ContainerInfo;
import com.toshiba.mwcloud.gs.GSType;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.Row;
import com.toshiba.mwcloud.gs.RowSet;
//----------- Weka ---------
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.classifiers.trees.RandomForest;
import weka.classifiers.Evaluation;
データセット
ランダムフォレストアルゴリズムを実装するために、ある大学への入学を予測するためのadmissionデータセットを用います。このデータセットは9個の属性と400個のインスタンスから構成されます。
このデータセットを構成する属性は、以下の通りです。
- Serial No: 学生を識別するために使用される数値
- GRE Score: 学生のGREスコアを数値化したもの
- TOEFL Score: 学生のTOEFLスコアを数値化したもの
- University Rating: 大学ランキングを数値化したもの
- SOP: 学生のSOPスコアを数値化したもの
- LOR: 学生のLORスコアを数値化したもの
- CGPA: 学生の累積GPAスコアを数値化したもの
- Research: 学生の研究成績を数値化するために使用する2値
- Chance of Admit: 入学の可能性を数値化するために使用される2値
以下はデータセットを一部抜粋したものです。
Serial No.,GRE Score,TOEFL Score,University Rating,SOP,LOR ,CGPA,Research,Chance of Admit
1,337,118,4,4.5,4.5,9.65,1,0.92
2,324,107,4,4,4.5,8.87,1,0.76
3,316,104,3,3,3.5,8,1,0.72
4,322,110,3,3.5,2.5,8.67,1,0.8
5,314,103,2,2,3,8.21,0,0.65
6,330,115,5,4.5,3,9.34,1,0.9
次のステップへの準備を確実にするために、Javaを使ってデータセットファイルを読み込む必要があります。この作業が終わったら、次はこのデータセットをGridDBデータベースを使って長期保存場所に書き込むことになります。
以下のコードでは、上記のステップを実装しています。
File data = new File("/home/ubuntu/griddb/gsSample/Admission_Predict.csv");
Scanner sc = new Scanner(data);
sc.useDelimiter("\n");
while (sc.hasNext()) // Returns a boolean value
{
int i = 0;
Row row = collection.createRow();
String line = sc.next();
String columns[] = line.split(",");
int serial = Integer.parseInt(columns[0]);
int gre = Integer.parseInt(columns[1]);
int toefl = Integer.parseInt(columns[2]);
int rating = Integer.parseInt(columns[3]);
float sop = Float.parseFloat(columns[4]);
float lor = Float.parseFloat(columns[5]);
float cgpa = Float.parseFloat(columns[6]);
int research = Integer.parseInt(columns[7]);
float admitclass = Float.parseFloat(columns[8]);
row.setInteger(0, i);
row.setInteger(1, serial);
row.setInteger(2, gre);
row.setInteger(3, toefl);
row.setInteger(4, rating);
row.setFloat(5, sop);
row.setFloat(6, lor);
row.setFloat(7, cgpa);
row.setInteger(8, research);
row.setFloat(9, admitclass);
rowList.add(row);
i++;
}
Javaによるランダムフォレストアルゴリズムの実装
この記事では、決定木よりも優れたパフォーマンスを発揮する分類モデルを使用します。このアルゴリズムは、ランダムフォレストとして知られています。このモデルを選択した背景を理解するために、決定木と比較した場合の利点を説明する必要があります。まず、ランダムフォレストアルゴリズムは、多くのツリーを作成し、その結果を平均化することで、可能な限り最高の精度を得ることができるという利点があります。このため、数値データセットに適したモジュールと言えます。第二に、我々の問題文は、大学の点数と成績に大きく依存する入学者選抜の予測を扱っています。つまり、このようなデータセットには、数値データセットであるランダムフォレストが最適なソリューションとなります。Javaによる実装では、ランダムフォレストのソースコードが含まれている weka.classifier.trees
パッケージを使用します。このライブラリは、我々のモデルを実装し、データの妥当性と精度をチェックするための評価プロセスを実行するために使用されます。
GridDBにデータを書き込む
長期保存データベースにデータを書き込むことは、再利用性とモデルのアクセシビリティにとって重要です。このセットは List<Row>
データ型を使用して行うことができ、リストに値を格納することから始めて、後で GridDB データベースに追加します。
本節で説明する作業を行うために、以下のコードを使用しました。
row.setInteger(0, i);
row.setInteger(1, serial);
row.setInteger(2, gre);
row.setInteger(3, toefl);
row.setInteger(4, rating);
row.setFloat(5, sop);
row.setFloat(6, lor);
row.setFloat(7, cgpa);
row.setInteger(8, research);
row.setFloat(9, admitclass);
rowList.add(row);
GridDBにデータを格納する
データを GridDB データベースに格納するために、適切なカラム名とデータ型を使用する必要があります。まず、この作業は属性のデータ型を調べ、データベースでその型をマッピングすることで達成されます。データセットを調査した結果、すべての属性が数値であることが分かりました。つまり、文字列やキャラクターをマッピングする必要はありません。次に、整数値と浮動小数点数を区別する必要があります。このステップは非常に簡単で、大学の成績はすべて浮動小数点数で、その他の値は整数値であることが分かります。
本節で説明する作業を行うために、以下のコードを使用しました。
ContainerInfo containerInfo = new ContainerInfo();
List<columninfo> columnList = new ArrayList<columninfo>();
columnList.add(new ColumnInfo("key", GSType.INTEGER));
columnList.add(new ColumnInfo("Serial No.", GSType.INTEGER));
columnList.add(new ColumnInfo("GRE Score", GSType.INTEGER));
columnList.add(new ColumnInfo("TOEFL Score", GSType.INTEGER));
columnList.add(new ColumnInfo("University Rating", GSType.INTEGER));
columnList.add(new ColumnInfo("SOP", GSType.FLOAT));
columnList.add(new ColumnInfo("LOR", GSType.FLOAT));
columnList.add(new ColumnInfo("CGPA", GSType.FLOAT));
columnList.add(new ColumnInfo("Research", GSType.INTEGER));
columnList.add(new ColumnInfo("Chance of Admit", GSType.FLOAT));
GridDBからデータを取得する
コードの有効性をテストすることは、私たちの主要な優先事項の1つであるべきです。このセクションでは、データがデータベースに正しく保存され、問題がないことを検証します。このタスクを実行するために、データベースのすべての値を返す SELECT
クエリを使用します。
本節で説明する作業を行うために、以下のコードを使用しました。
Container, Row> container = store.getContainer(containerName);
if (container == null) {
throw new Exception("Container not found.");
}
Query<row> query = container.query("SELECT * ");
RowSet<row> rowset = query.fetch();
データベースからデータを取り出したら、次はデータを表示する作業です。この作業には、Javaのループとprint文が必要です。
以下のコードで、GridDBのデータを表示しました。
while (rowset.hasNext()) {
Row row = rowset.next();
int serial = row.getInt(0);
float gre = row.getFloat(1);
float toefl = row.getFloat(2);
float rating = row.getFloat(3);
int sop = row.getInt(4);
int lor = row.getInt(5);
int cgpa = row.getInt(6);
int research = row.getInt(7);
float admitclass =row.getFloat(8);
System.out.println(admitclass);
}
ランダムフォレストの構築
ランダム・フォレスト・モデルは、ある学生の大学への入学予測を決定するために、Java プログラムで呼び出されます。このコードではまずランダムフォレストオブジェクトを初期化します。次に、デフォルトの WEKA
パラメータを用いて、ランダムフォレストアルゴリズムを設定します。
本節で説明する作業を行うために、以下のコードを使用しました。
RandomForest randomForest = new RandomForest();
String[] parameters = new String[14];
parameters[0] = "-P";
parameters[1] = "100";
parameters[2] = "-I";
parameters[3] = "100";
parameters[4] = "-num-slots";
parameters[5] = "1";
parameters[6] = "-K";
parameters[7] = "0";
parameters[8] = "-M";
parameters[9] = "1.0";
parameters[10] = "-V";
parameters[11] = "0.001";
parameters[12] = "-S";
parameters[13] = "1";
randomForest.setOptions(parameters);
パラメータを定義したら、それをモデルに設定する必要があります。次のステップでは、学習データを使ってモデルを学習させます。最後に、入場データセットを使ってモデルを学習させたら、その結果を評価することでコードを終了します。
本節で説明する作業を行うために、以下のコードを使用しました。
randomForest.setOptions(parameters);
randomForest.buildClassifier(datasetInstances);
Evaluation evaluation = new Evaluation(datasetInstances);
evaluation.crossValidateModel(randomForest, datasetInstances, numFolds, new Random(1));
System.out.println(evaluation.toSummaryString("\nResults\n======\n", true));
コードのコンパイルと実行
Javaプログラムをコンパイルするために、GridDBフォルダを探し、コマンドラインを使用してプログラムを実行します。Javaプログラムをコンパイルするために、javac
コマンドを使用します。コンパイルが完了したら、次に java
コマンドを使用してコードを実行します。
具体的には以下のようなコマンドです。
javac gsSample/randomForest.java
java gsSample/randomForest
結論と結果
ランダムフォレストのJavaプログラムをコンパイルした後、私たちのモデルの性能を理解するために、結果を調べる必要があります。結果の要約には、モデルの精度が含まれています。さらに、結果のサマリーには、属性数、インスタンス数、モデルの情報の精度が含まれています。我々のランダムフォレストは、わずか400個のインスタンスを使用して、91.05%の精度に達しました。説明すると、我々のデータセットのインスタンス数は、大規模なデータセットを表しているわけではないのです。つまり、より大きなデータセットがあれば、より広範囲に精度を向上させることができるのです。さらに、GridDBデータベースは動的であり、我々のモデルで実行する追加インスタンスを容易に格納することができます。
以下の要約出力は、私たちの結果と評価情報を表現したものです。
=== Run information ===
Relation: Admission_Predict
Instances: 400
Attributes: 9
Serial No.
GRE Score
TOEFL Score
University Rating
SOP
LOR
CGPA
Research
Chance of Admit
=== Summary ===
Correctly Classified Instances 91.05 %
Incorrectly Classified Instances 8.96 %
Mean absolute error 0.0896
Relative absolute error 9.19 %
Root relative squared error 34.3846 %
Total Number of Instances 400
ブログの内容について疑問や質問がある場合は Q&A サイトである Stack Overflow に質問を投稿しましょう。 GridDB 開発者やエンジニアから速やかな回答が得られるようにするためにも "griddb" タグをつけることをお忘れなく。 https://stackoverflow.com/questions/ask?tags=griddb