多層パーセプトロンとは、入力データを適切な出力に対応させるフィードフォワードの人工ニューラルネットワークモデルを指します。
多層パーセプトロンは、有向グラフのノードを何層にも重ねたもので、どの層も次の層と完全につながっています。
今回は、パーセプトロンのニューラルネットワークモデルをJavaとGridDBで実装します。このモデルの目的は、住宅価格が住宅価格の中央値より上か下かを予測することです。
完全なソースコードと csv
ファイルは、Github リポジトリで見ることができます。https://github.com/griddbnet/Blogs/tree/perceptron_neural_network
データ説明
使用するデータセットには11個の特徴があります。データセットの最初の10個の属性が入力の特徴となります。これらの入力はそれぞれ家の特徴を記述しています。
データセットの最後の属性は、これから予測する特徴で、住宅価格が中央値より上か下かを記述します。値が1であれば、住宅価格が中央値より高いことを意味し、値が0であれば、住宅価格が中央値より低いことを意味します。
データは、housepricedata.csv
という名前のCSVファイルに格納されています。
GridDBにデータを格納する
ファイルからデータを利用することもできますが、GridDBはCSVファイルよりも多くの利点を備えています。例えば、GridDBはCSVファイルよりも高速にクエリ結果を返すことができます。
そのため、GridDBにデータを保存することを選択しました。
まず、GridDBにデータを格納するためのライブラリをインポートしましょう。
import java.io.File;
import java.util.Scanner;
import java.io.IOException;
import java.util.Collection;
import java.util.Properties;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStoreFactory;
GridDBは、データをコンテナにグループ化します。データを格納するGridDBコンテナを表す静的なJavaクラスを作成しましょう。
public static class Perceptron {
@RowKey String lotArea;
String overallQual;
String overallCond;
String totalBsmtSF;
String fullBath;
String halfBath;
String bedroomAbvGr;
String totRmsAbvGrd;
String fireplaces;
String garageArea;
String aboveMedianPrice;
}
上記の静的クラスをSQLのテーブルと見立ててください。各変数は、GridDB コンテナ内の 1 つのカラムを表します。
これで、JavaからGridDBに接続できるようになりました。これは、認証プロセスが正常に実行されれば可能です。接続は以下のコードで行うことができます。
Properties props = new Properties();
props.setProperty("notificationAddress", "239.0.0.1");
props.setProperty("notificationPort", "31999");
props.setProperty("clusterName", "defaultCluster");
props.setProperty("user", "admin");
props.setProperty("password", "admin");
GridStore store = GridStoreFactory.getInstance().getGridStore(props);
GridDBの設定に応じて、正しい認証情報を使用することを確認してください。
では、データを挿入するコンテナを選択してみましょう。
Collection<String, Perceptron> coll = store.putCollection("col01", Perceptron.class);
コンテナのインスタンスが作成されました。このインスタンスは、コンテナを参照するために使用することができます。
それでは、CSVファイルからデータを取り出し、GridDBに挿入してみましょう。
File file1 = new File("housepricedata.csv");
Scanner sc = new Scanner(file1);
String data = sc.next();
while (sc.hasNext()){
String scData = sc.next();
String dataList[] = scData.split(",");
String lotArea = dataList[0];
String overallQual = dataList[1];
String overallCond = dataList[2];
String totalBsmtSF = dataList[3];
String fullBath = dataList[4];
String halfBath = dataList[5];
String bedroomAbvGr = dataList[6];
String totRmsAbvGrd = dataList[7];
String fireplaces = dataList[8];
String garageArea = dataList[9];
String aboveMedianPrice = dataList[10];
Perceptron pc = new Perceptron();
pc.lotArea = Integer.parseInt(lotArea);
pc.overallQual = Integer.parseInt(overallQual);
pc.overallCond = Integer.parseInt(overallCond);
pc.totalBsmtSF = Integer.parseInt(totalBsmtSF);
pc.fullBath = Integer.parseInt(fullBath);
pc.halfBath = Integer.parseInt(halfBath);
pc.bedroomAbvGr = Integer.parseInt(bedroomAbvGr);
pc.totRmsAbvGrd = Integer.parseInt(totRmsAbvGrd);
pc.fireplaces = Integer.parseInt(fireplaces);
pc.garageArea = Integer.parseInt(garageArea);
pc.aboveMedianPrice = Integer.parseInt(aboveMedianPrice);
coll.append(pc);
}
上記のコードでは、CSV ファイルからデータを取り出し、GridDB コンテナに挿入しています。
データを取得する
このデータを使って、パーセプトロンのニューラルネットワークモデルを実装してみたいと思います。そこで、GridDBコンテナからデータを取り出してみましょう。
Query<perceptron> query = coll.query("select *");
RowSet<perceptron> rs = query.fetch(false);
RowSet res = query.fetch();
select *
は、GridDBコンテナに格納されているデータをすべて選択するのに役立ちました。
パーセプトロンニューラルネットワークモデルの適合
このデータセットを使って機械学習モデルを作ることができるようになりました。まず、これに使うライブラリをインポートしましょう。
import java.io.BufferedReader;
import java.io.FileReader;
import weka.core.Instance;
import weka.core.Instances;
import java.math.BigDecimal;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.MultilayerPerceptron;
データセットのバッファードリーダーを作成します。
BufferedReader bufferedReader
= new BufferedReader(
new FileReader(res));
// Create dataset instances
Instances datasetInstances
= new Instances(bufferedReader);
次に、PerceptronNeuralNetwork
クラスのインスタンスを作成します。
datasetInstances.setClassIndex(datasetInstances.numAttributes()-1);
//Instance of NN
PerceptronNeuralNetwork mlp = new PerceptronNeuralNetwork();
それでは、ニューラルネットワークのパラメータを設定しましょう。
//Setting Parameters
mlp.setLearningRate(0.1);
mlp.setMomentum(0.2);
mlp.setTrainingTime(2000);
mlp.setHiddenLayers("3");
上記のコードにより、モデルの学習率、運動量、学習時間、隠れ層の数などのパラメータを設定することができました。
それでは、分類器を作ってみましょう。
mlp.buildClassifier(datasetInstances);
モデルを評価する
では、このモデルがどのように動作するかを評価しましょう。評価には Evaluation()
関数を使います。
Evaluation eval = new Evaluation(datasetInstances);
eval.evaluateModel(mlp, datasetInstances);
System.out.println(eval.toSummaryString()); //Summary of Training
これで評価指標を表示できるようになりました。
//display metrics
System.out.println("Correlation: "+eval.correlationCoefficient());
System.out.println("Mean Absolute Error: "+new BigDecimal(eval.meanAbsoluteError()));
System.out.println("Root Mean Squared Error: "+eval.rootMeanSquaredError());
System.out.println("Relative Absolute Error: "+eval.relativeAbsoluteError()+"%");
System.out.println("Root Relative Squared Error: "+eval.rootRelativeSquaredError()+"%");
System.out.println("Instances: "+eval.numInstances());
予測する
データセットの最後のインスタンスを使って、予測をしてみましょう。目標は、住宅価格が中央値より上か下かを知ることです。結果が1であれば、住宅価格が中央値より上であることを意味します。0 は家の値段が中央値より下であることを意味します。
Instance pred = datasetInstances.lastInstance();
double answer = mlp.classifyInstance(pred);
System.out.println(answer);
モデルのコンパイルと実行
モデルをコンパイルして実行するには、Weka API が必要です。以下のURLからダウンロードしてください。
http://www.java2s.com/Code/Jar/w/weka.htm
次に、gsadm
ユーザでログインします。作成した .java
ファイルを GridDB の bin
フォルダに移動しましょう。移動先は以下の通りです。
/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin
Linux端末で以下のコマンドを実行し、gridstore.jarファイルのパスを設定します。
export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar
次に、以下のコマンドを使用して .java
ファイルをコンパイルします。
javac -cp weka-3-7-0/weka.jar PerceptronNeuralNetwork.java
以下のコマンドを実行して生成された .class ファイルを実行します。
java -cp .:weka-3-7-0/weka.jar PerceptronNeuralNetwork
モデルは予測値として0
を返しました。これは、住宅価格が中央値より低いことを意味します。
ブログの内容について疑問や質問がある場合は Q&A サイトである Stack Overflow に質問を投稿しましょう。 GridDB 開発者やエンジニアから速やかな回答が得られるようにするためにも "griddb" タグをつけることをお忘れなく。 https://stackoverflow.com/questions/ask?tags=griddb