Javaによる線形回帰

はじめに

線形回帰は、単純線形回帰とも呼ばれ、従属変数と1つの独立変数の間の関係をモデル化する回帰アルゴリズムです。線形回帰モデルは、線形または傾斜した直線である関係を示すので、単純線形回帰という名前になりました。

線形回帰では、従属変数は実数または連続値でなければなりません。しかし、独立変数は連続値またはカテゴリ値で測定できます。

単純線形回帰アルゴリズムの主な目的は以下の2つです。

  • 職歴と給与、収入と支出など、2つの与えられた変数の間の関係をモデル化します。

  • 新しいオブザベーションを予測します。例えば、1年間の投資に基づいて企業が生み出す収益額、気温に基づく天気予報、などです。

単純な線形回帰は、以下の式で表すことができます。

Y = a + bX

ここで、

Y- は従属変数である。

a- は回帰直線の切片である。

b- は直線の傾き。

X- は独立変数。

Javaで線形回帰を実装する

このセクションでは、Javaで単純回帰アルゴリズムを実装する方法について説明します。使用するデータセットには、従属変数である salary と独立変数である experience の2つの変数があります。

このデータセットから線形回帰モデルを構築します。そして、このモデルを使って、個人の経験に基づいてその人の給料を予測します。

GridDBにデータを格納する

データはCSVファイル Salary_Data.csv に格納されていますが、これをGridDBコンテナに書き込む必要があります。CSVファイルから直接データを読み込むこともできますが、GridDBを利用するとクエリのパフォーマンスが向上します。

データセットには、yearsExperiencesalary という2つのカラムがあります。

まずは、必要なライブラリをインポートしてみましょう。

import java.io.IOException;
import java.util.Collection;
import java.util.Properties;
import java.util.Scanner;
import java.io.File;

import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;

次に、使用する GridDB コンテナを表す静的 Java クラスを作成します。

public static class ExperienceSalary {
     @RowKey float yearsExperience;
     Double  salary; 
}    

上記のクラスは、2つのカラムを持つSQLテーブルと同等です。2つの変数は、GridDBコンテナのカラムを表します。

それでは、Java を使って GridDB に接続してみましょう。以下のように、GridDB をインストールした際の認証情報を使用します。

        Properties props = new Properties();
        props.setProperty("notificationAddress", "239.0.0.1");
        props.setProperty("notificationPort", "31999");
        props.setProperty("clusterName", "defaultCluster");
        props.setProperty("user", "admin");
        props.setProperty("password", "admin");
        GridStore store = GridStoreFactory.getInstance().getGridStore(props);

GridDBのExperienceSalaryコンテナを使用するので、これを選択します。

Collection<String, ExperienceSalary> coll = store.putCollection("col01", ExperienceSalary.class);

コンテナのインスタンスを作成し、collという名前を付けました。このインスタンス名を使って、コンテナを参照することにします。

それでは、Salary_Data.csv ファイルからデータを読み込んで、GridDB コンテナに書き込んでみましょう。


                File file1 = new File("Salary_Data.csv");
                Scanner sc = new Scanner(file1);
                String data = sc.next();
 
                while (sc.hasNext()){
                        String scData = sc.next();
                        String dataList[] = scData.split(",");
                        String yearsExperience = dataList[0];
                        String salary = dataList[1];

                        ExperienceSalary es = new ExperienceSalary();

                        es.yearsExperience = Float.parseFloat(yearsExperience);
                        es.salary = Double.parseDouble(salary);

                        coll.append(es);
                 }

静的クラスのオブジェクトを作成し、es という名前を付けているのがわかります。このオブジェクトは、GridDB コンテナに追加されます。

GridDBからデータを取得する

線形回帰モデルを実装するためにデータを使用する必要があるので、GridDBコンテナからデータを取り出してみましょう。


            Query<experiencesalary> query = coll.query("select *");
            RowSet<experiencesalary> rs = query.fetch(false);
            RowSet res = query.fetch();
  

ここでは、select *クエリを使って、GridDBコンテナに格納されているすべてのデータを取り出しています。

線形回帰モデルの構築

さて、データが揃ったので、いよいよデータセットの2つの変数の間に相関があるかどうかを教えてくれる線形回帰モデルを実装することにしましょう。

まず、使用するライブラリをインポートしましょう。

import java.io.IOException;

import weka.classifiers.Evaluation;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader; 

import java.io.BufferedReader;
import java.io.FileReader;

次に、データセット用のバッファードリーダーを作成しましょう。また、データセットのインスタンスも作成します。

            BufferedReader bufferedReader
                = new BufferedReader(
                    new FileReader(res));
 
            // Create dataset instances
            Instances datasetInstances
                = new Instances(bufferedReader);

それでは、このデータセットを使って線形回帰分類器を実装してみましょう。


        datasetInstances.setClassIndex(datasetInstances.numAttributes()-1);
        Classifier classifier = new weka.classifiers.functions.LinearRegression();
        classifier.buildClassifier(datasetInstances);

予測をする

さて、モデルの準備ができたので、これを使って予測をしてみましょう。今回は、データセットの最後のインスタンスの給料を予測します。


        // Predicting the salary 
        Instance predSalary = datasetInstances.lastInstance();
        double sal = classifier.classifyInstance(predSalary);
        System.out.println("The salary is : "+sal);

モデルのコンパイルと実行

モデルをコンパイルして実行するには、Weka API が必要です。以下のURLからダウンロードすることができます。

http://www.java2s.com/Code/Jar/w/weka.htm

次に、gsadm ユーザでログインします。作成した .java ファイルを GridDB の bin フォルダに移動し、以下のパスに配置します。

/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin

次に、Linux端末で以下のコマンドを実行し、gridstore.jarファイルのパスを設定します。

export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar

次に、以下のコマンドを実行して .java ファイルをコンパイルします。

javac -cp weka-3-7-0/weka.jar LinearRegression.java

以下のコマンドを実行して生成された .class ファイルを実行します。

java -cp .:weka-3-7-0/weka.jar LinearRegression

このモデルは122017.00を返しましたが、これはそのインスタンスの実際の給与、つまり121872.00に非常に近いものでした。

Leave a Reply

Your email address will not be published.