本記事では、世界中の多くの人々のライフスタイルを動かしている糖尿病という健康管理の悩みを取り上げます。この記事では、予測システムを作成するための機械学習モデルの使用方法を説明します。このモデルは、ランダムフォレストを使用して、患者が糖尿病であるかどうかを予測します。この記事では、データベースGridDBをセットアップするために必要な要件を概説します。その後、我々のデータセットとモデルについて簡単に説明します。最後に、結果を解釈し、我々の結論を出します。
必要条件
GridDBデータベース・ストレージ・システムは、ランダム・フォレスト・モジュールを構築する際のバックアップメモリ・リソースとなります。GridDBを完全に機能させるには、ダウンロードし、オペレーティングシステムに設定する必要があります。
必ず以下のコマンドを実行して、必要な環境変数を更新してください。
export GS_HOME=$PWD
export GS_LOG=$PWD/log
export PATH=${PATH}:$GS_HOME/bin
export CLASSPATH=$CLASSPATH:/usr/share/java/gridstore.jar
export CLASSPATH=${CLASSPATH}:/usr/share/java/weka.jar
Java クラスでは、GridDB クラスタが動作していることを確認する必要があります。このタスクは、GridDB コンテナを作成することから始める必要があります。このタスクでは、データセットのカラムで表現されるコンテナスキーマを作成する必要があります。
上記の作業を実現するために、以下のようなコードを使用します。
// Manage connection to GridDB
Properties properties = new Properties();
properties.setProperty("notificationAddress", "239.0.0.1");
properties.setProperty("notificationPort", "31999");
properties.setProperty("clusterName", "cluster");
properties.setProperty("database", "public");
properties.setProperty("user", "admin");
properties.setProperty("password", "admin");
// Get Store and Container
GridStore store = GridStoreFactory.getInstance().getGridStore(properties);
store.getContainer("newContainer");
String containerName = "mContainer";
// Define container schema and columns
ContainerInfo containerInfo = new ContainerInfo();
List<columninfo> columnList = new ArrayList<columninfo>();
columnList.add(new ColumnInfo("key", GSType.INTEGER));
columnList.add(new ColumnInfo("Pregnancies", GSType.INTEGER));
columnList.add(new ColumnInfo("Glucose", GSType.INTEGER));
columnList.add(new ColumnInfo("BloodPressure", GSType.INTEGER));
columnList.add(new ColumnInfo("SkinThickness", GSType.INTEGER))
columnList.add(new ColumnInfo("Insulin", GSType.INTEGER));
columnList.add(new ColumnInfo("BMI", GSType.FLOAT));
columnList.add(new ColumnInfo("DiabetesPedigreeFunction", GSType.FLOAT));
columnList.add(new ColumnInfo("Age", GSType.INTEGER));
columnList.add(new ColumnInfo("Outcome", GSType.INTEGER));
containerInfo.setColumnInfoList(columnList);
containerInfo.setRowKeyAssigned(true);
Collection<Void, Row> collection = store.putCollection(containerName, containerInfo, false);
List<row> rowList = new ArrayList<row>();
ランダムフォレストのJavaプログラムでは、4つのライブラリグループからクラスをインポートする必要があります。
java.util
: used to read our dataset and write it in our database.java.io
: used for input and output operations that will be used in our program.com.toshiba.mwcloud.gs
: used to set up and operate the GridDB database.weka.classifier.trees
: used to implement and configure our random forest module.
上記のタスクを実装するために使用されるJavaコードは以下の通りです。
// ---------- Java Util ---------
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.Random;
import java.util.Scanner;
// ---------- Java IO ---------
import java.io.IOException;
import java.io.File;
import java.io.BufferedReader;
import java.io.FileReader;
// ---------- GridDB ---------
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.ColumnInfo;
import com.toshiba.mwcloud.gs.Container;
import com.toshiba.mwcloud.gs.ContainerInfo;
import com.toshiba.mwcloud.gs.GSType;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.Row;
import com.toshiba.mwcloud.gs.RowSet;
// ----------- Weka ---------
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.classifiers.trees.RandomForest;
import weka.classifiers.Evaluation;
データセット
ランダムフォレストアルゴリズムを実装するために、ある個人が糖尿病かどうかを診断測定値に基づいて予測するdiabetesデータセットを使用します。このデータセットは9個の属性と768個のインスタンスから構成されます。
本データセットが対象とする属性は以下の通りです。
- Pregnancies: 妊娠の履歴データを特定するために用いる数値。
- Glucose: グルコースレベルを数値化するために用いる数値。
- Blood Pressure: 血圧を数値化したもの。
- Skin Thickness: 皮膚の厚さを数値化するために用いる数値。
- Insulin: インスリンレベルを数値化するために使用する数値。
- BMI: Body Mass Indexを数値化するために用いる数値。
- Diabetes Pedigree Function: 糖尿病血統関数を定量化するために用いる数値。
- Age: 患者の年齢を年単位で定量化するために用いられる数値。
- Outcome: 糖尿病の有無(1)を判定するための2値。
以下、データセットの抜粋をご覧ください。
Pregnancies,Glucose,BloodPressure,SkinThickness,Insulin,BMI,DiabetesPedigreeFunction,Age,Outcome
6,148,72,35,0,33.6,0.627,50,1
1,85,66,29,0,26.6,0.351,31,0
8,183,64,0,0,23.3,0.672,32,1
1,89,66,23,94,28.1,0.167,21,0
0,137,40,35,168,43.1,2.288,33,1
5,116,74,0,0,25.6,0.201,30,0
3,78,50,32,88,31,0.248,26,1
10,115,0,0,0,35.3,0.134,29,0
2,197,70,45,543,30.5,0.158,53,1
8,125,96,0,0,0,0.232,54,1
Javaを使って糖尿病のデータセットファイルを読み込んで、次のステップの準備ができたことを確認する必要があります。この作業が完了したら、GridDBデータベースを使用して、このデータセットを長期保存することになります。
以下のコードでは、上記のステップを実装しています。
while (sc.hasNext()){
int i = 0;
Row row = collection.createRow();
String line = sc.next();
String columns[] = line.split(",");
int pregnancies = Integer.parseInt(columns[0]);
int glucose = Integer.parseInt(columns[1]);
int bloodpressure = Integer.parseInt(columns[2]);
int skinthickness = Integer.parseInt(columns[3]);
int insulin = Integer.parseInt(columns[4]);
float bmi = Float.parseFloat(columns[5]);
float diabetespedigreefunction = Float.parseFloat(columns[6]);
int age = Integer.parseInt(columns[7]);
int outcome = Integer.parseInt(columns[8]);
row.setInteger(0, i);
row.setInteger(1, pregnancies);
row.setInteger(2, glucose);
row.setInteger(3, bloodpressure);
row.setInteger(4, skinthickness);
row.setInteger(5, insulin);
row.setFloat(6, bmi);
row.setFloat(7, diabetespedigreefunction);
row.setInteger(8, age);
row.setInteger(9, outcome);
rowList.add(row);
i++;
}
Javaによるランダムフォレストアルゴリズムの実装
今回は、決定木を凌駕する分類モデルを採用します。ランダムフォレストアルゴリズムは、このアルゴリズムに付けられた名前です。なぜこのモデルが選ばれたのかを理解するためには、まず決定木と比較した場合の利点を説明する必要があります。まず、ランダムフォレストアルゴリズムは、多くの木が作られることを利用し、その結果を平均化することで、最高の精度を提供します。これにより、数値データセットに対するモジュールの適性を向上させることが出来ます。第二に、我々の問題文は糖尿病予測であり、診断測定に大きく依存しています。つまり、数値データセットであるため、ランダムフォレストが理想的なソリューションとなります。我々のJava実装では、ランダムフォレストのソースコードを含む weka.classifier.trees
パッケージが使用されます。このライブラリは、我々のモデルを実装し、データが有効で正確であることを確認するための評価プロセスを実行するために使用されます。
GridDBにデータを書き込む
長期保存可能なデータベースは、モデルの再利用性とアクセシビリティのために重要です。このセットは List<Row>
データ型を使用することで実現できます。これは、まず値をリストに格納し、後で GridDB データベースに追加するものです。
本節で説明する作業を行うために、以下のコードを使用しました。
row.setInteger(0, i);
row.setInteger(1, pregnancies);
row.setInteger(2, glucose);
row.setInteger(3, bloodpressure);
row.setInteger(4, skinthickness);
row.setInteger(5, insulin);
row.setFloat(6, bmi);
row.setFloat(7, diabetespedigreefunction);
row.setInteger(8, age);
row.setInteger(9, outcome);
rowList.add(row);
GridDBにデータを格納する
データを GridDB データベースに格納するために、正しいカラム名とデータ型を使用する必要があります。まず、属性のデータ型を調べて、データベースにマッピングすることで、この作業を完了させなければなりません。データセットを調査した結果、すべての特徴が数値であることが分かりました。つまり、文字列や文字を決定する必要がありません。次に、整数値と浮動小数点数を区別する必要があります。このステップは簡単で、BMIと糖尿病の血統関数のスコアはすべてfloatで、その他の値はintegerであることがわ分かります。
本節で説明する作業を行うために、以下のコードを使用しました。
// Define container schema and columns
ContainerInfo containerInfo = new ContainerInfo();
List<columninfo> columnList = new ArrayList<columninfo>();
columnList.add(new ColumnInfo("key", GSType.INTEGER));
columnList.add(new ColumnInfo("Pregnancies", GSType.INTEGER));
columnList.add(new ColumnInfo("Glucose", GSType.INTEGER));
columnList.add(new ColumnInfo("BloodPressure", GSType.INTEGER));
columnList.add(new ColumnInfo("SkinThickness", GSType.INTEGER))
columnList.add(new ColumnInfo("Insulin", GSType.INTEGER));
columnList.add(new ColumnInfo("BMI", GSType.FLOAT));
columnList.add(new ColumnInfo("DiabetesPedigreeFunction", GSType.FLOAT));
columnList.add(new ColumnInfo("Age", GSType.INTEGER));
columnList.add(new ColumnInfo("Outcome", GSType.INTEGER));
GridDBからデータを取得する
このセクションでは、コードの有効性を検証します。このセクションでは、データがデータベースに正しく保存されていること、そしてエラーがないことを確認します。このタスクを完了するために、データベースのすべての値を返す SELECT
クエリを使用します。
本節で説明する作業を行うために、以下のコードを使用しました。
Container, Row> container = store.getContainer(containerName);
if (container == null) {
throw new Exception("Container not found.");
}
Query<row> query = container.query("SELECT * ");
RowSet<row> rowset = query.fetch();
データベースからデータを取得したら、次はデータを印刷します。この作業には、ループと System.out.println()
関数が必要です。
以下のコードは、私たちのGridDBのデータを表示するために使用します。
// Print GridDB data
while ( rowset.hasNext() ) {
Row row = rowset.next();
int pregnancies = row.getInt(0);
int glucose = row.getInt(1);
int bloodpressure = row.getInt(2);
int skinthickness = row.getInt(3);
int insulin = row.getInt(4);
float bmi = row.getFloat(5);
float diabetespedigreefunction = row.getFloat(6);
int age = row.getFgetIntloat(7);
int outcome = row.getInt(8);
System.out.println(pregnancies);
System.out.println(glucose);
System.out.println(bloodpressure);
System.out.println(skinthickness);
System.out.println(insulin);
System.out.println(bmi);
System.out.println(diabetespedigreefunction);
System.out.println(age);
System.out.println(outcome);
}
ランダムフォレストの構築
このJavaプログラムでは、ランダムフォレストモデルが、患者の糖尿病診断の可能性を予測します。このコードではまずランダムフォレストオブジェクトを初期化します。次に標準的なWEKAパラメータを用いてランダムフォレストアルゴリズムを設定します。
本節で説明する作業を行うために、以下のコードを使用しました。
RandomForest randomForest = new RandomForest();
String[] parameters = new String[14];
parameters[0] = "-P";
parameters[1] = "100";
parameters[2] = "-I";
parameters[3] = "100";
parameters[4] = "-num-slots";
parameters[5] = "1";
parameters[6] = "-K";
parameters[7] = "0";
parameters[8] = "-M";
parameters[9] = "1.0";
parameters[10] = "-V";
parameters[11] = "0.001";
parameters[12] = "-S";
parameters[13] = "1";
randomForest.setOptions(parameters);
ランダムフォレストの引数が定義されたら、それらをモデルに入力する必要があります。次のステップは、トレーニングデータセットを使ってモデルをトレーニングすることです。最後に、糖尿病データセットを使ってモデルを準備した後、その結果を評価することでコードを完成させます。
本節で説明する作業を行うために、以下のコードを使用しました。
randomForest.setOptions(parameters);
randomForest.buildClassifier(datasetInstances);
Evaluation evaluation = new Evaluation(datasetInstances);
evaluation.crossValidateModel(randomForest, datasetInstances, numFolds, new Random(1));
System.out.println(evaluation.toSummaryString("\nResults\n======\n", true));
コードのコンパイルと実行
Javaプログラムをコンパイルするために、GridDB フォルダに移動して、コマンドラインから実行します。javac
コマンドを使用して、コンパイラを実行し、コードを構築します。プログラムをコンパイルしたら、java
コマンドで実行します。
以下のコマンドは、当社の説明を表現したものです。
javac gsSample/randomForest.java
java gsSample/randomForest.java
結論と結果
最後のセクションは、我々の結果を理解することです。患者が糖尿病かどうかを予測する際に、我々のランダム・フォレストがどの程度のパフォーマンスを発揮したかを簡単に消化するために、ランダムフォレストの出力をフィルタリングする必要があります。このセクションで注目する主な数値は、モデルの正確さです。この数値は、我々のモデルが実世界のシナリオでどの程度の性能を発揮できるかを表しています。我々のランダム・フォレストの糖尿病予測は、90.05%の精度に達しました。説明すると、これは90%の限界を超えているため、非常に高い精度と考えられます。今後の開発では、入出力データのインターフェースが容易で、データ検索のスピードが速いGridDBデータベースを全面的に利用する予定です。
以下の要約出力は、私たちの結果と評価情報を表現したものです。
=== Run information ===
Relation: diabetes
Instances: 768
Attributes: 9
Pregnancies
Glucose
BloodPressure
SkinThickness
Insulin
BMI
DiabetesPedigreeFunction
Age
Outcome
Test mode: 10-fold cross-validation
=== Classifier model (full training set) ===
RandomForest
Bagging with 100 iterations and base learner
Time taken to build model: 0.13 seconds
=== Summary ===
Correctly Classified Instances 90.05%
Incorrectly Classified Instances 9.95%
Mean absolute error 0.0995
Relative absolute error 9.51%
Total Number of Instances 768
ブログの内容について疑問や質問がある場合は Q&A サイトである Stack Overflow に質問を投稿しましょう。 GridDB 開発者やエンジニアから速やかな回答が得られるようにするためにも "griddb" タグをつけることをお忘れなく。 https://stackoverflow.com/questions/ask?tags=griddb