Email Spam Classification using Java and GridDB

Introduction

In email spam classification, the task is to determine whether an email is spam or not spam. Spam can often be detected from keywords in the subject or message body, but not always, and in harder cases other parts of the email must be considered as well. One way to solve this problem is to gather example spam and non-spam emails and then train a machine learning model on them.

We will use a dataset that contains over 4,000 emails, each labeled as spam (1) or non-spam (0).

The dataset is stored in CSV format, and each row describes one email with 58 columns:

- 48 features that give the frequency of particular words in the email.
- 6 features that give the frequency of particular characters.
- 3 features that describe runs of capital letters (average run length, longest run, and total number of capital letters).
- A final column with the spam/non-spam label.

Since each email is categorized as either spam or not spam, this is a binary classification task.
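To make the column layout concrete, the following minimal sketch splits one data line of spam.csv into its 57 feature values and its label. It assumes the column order described above and plain comma-separated numbers without quoting. The class SpambaseRow and its methods are hypothetical helpers, not part of GridDB or Deep Netts.

```java
import java.util.Arrays;

// Hypothetical helper that mirrors the spam.csv column layout:
// columns 0-47 hold word frequencies, 48-53 character frequencies,
// 54-56 capital-letter run statistics, and 57 the label (1 = spam, 0 = not spam).
final class SpambaseRow {
    static final int WORD_FEATURES = 48;
    static final int CHAR_FEATURES = 6;
    static final int CAPITAL_FEATURES = 3;
    static final int NUM_FEATURES = WORD_FEATURES + CHAR_FEATURES + CAPITAL_FEATURES; // 57
    static final int NUM_COLUMNS = NUM_FEATURES + 1;                                   // 58

    final float[] features;
    final int label;

    private SpambaseRow(float[] features, int label) {
        this.features = features;
        this.label = label;
    }

    // Parses one comma-separated data line (not the header) into features and label.
    static SpambaseRow parse(String line) {
        String[] cols = line.trim().split(",");
        if (cols.length != NUM_COLUMNS) {
            throw new IllegalArgumentException(
                    "expected " + NUM_COLUMNS + " columns, found " + cols.length);
        }
        float[] features = new float[NUM_FEATURES];
        for (int i = 0; i < NUM_FEATURES; i++) {
            features[i] = Float.parseFloat(cols[i]);
        }
        int label = Math.round(Float.parseFloat(cols[NUM_FEATURES]));
        return new SpambaseRow(features, label);
    }

    // The 6 character-frequency features (columns 48-53).
    float[] charFeatures() {
        return Arrays.copyOfRange(features, WORD_FEATURES, WORD_FEATURES + CHAR_FEATURES);
    }

    // The 3 capital-letter features (columns 54-56).
    float[] capitalFeatures() {
        return Arrays.copyOfRange(features, WORD_FEATURES + CHAR_FEATURES, NUM_FEATURES);
    }
}
```

Keeping the group sizes as named constants makes the 57 + 1 split explicit. The neural network later in this post uses the same split: 57 inputs and 1 output.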

We will follow the steps given below to build the binary classifier:

Step 1: Move the CSV data into GridDB.

Step 2: Read the GridDB data and create an in-memory dataset.

Step 3: Create a neural network model to perform binary classification.

Step 4: Use the loaded dataset to train the neural network model.

Step 5: Test the performance of the trained model.

Import Packages

Let’s begin by importing the packages to be used:

import java.io.File;
import java.io.IOException;
import java.util.Properties;
import java.util.Scanner;

import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.RowSet;

Write Data into GridDB

The CSV data has been stored in a file named spam.csv. Our goal is to move this data into a GridDB container. Let’s create a static class to represent the container schema:

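 public static class Emails {
     // no @RowKey: every put() adds a new row
     String word1, word2, word3, word4, word5, word6, word7, word8, word9, word10,
            word11, word12, word13, word14, word15, word16, word17, word18, word19, word20,
            word21, word22, word23, word24, word25, word26, word27, word28, word29, word30,
            word31, word32, word33, word34, word35, word36, word37, word38, word39, word40,
            word41, word42, word43, word44, word45, word46, word47, word48,
            char1, char2, char3, char4, char5, char6,
            capitalAvg, capitalLongest, capitalTotal, isSpam;
 }

The class above defines the schema of a GridDB container with 58 columns, one for each column of spam.csv. A GridDB container is similar to a SQL table. All values are stored as strings.

The class deliberately has no row key (@RowKey). Many emails share the same value in any given column, so a key column would make later rows overwrite earlier ones instead of adding new ones.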

We can now connect to our GridDB instance. We create a Properties object that holds the connection details of our GridDB installation and pass it to GridStoreFactory to obtain a GridStore instance:

        Properties props = new Properties();
        props.setProperty("notificationAddress", "239.0.0.1");
        props.setProperty("notificationPort", "31999");
        props.setProperty("clusterName", "defaultCluster");
        props.setProperty("user", "admin");
        props.setProperty("password", "admin");
        GridStore store = GridStoreFactory.getInstance().getGridStore(props);

Next, we create a collection named col01 that uses the Emails schema:

Collection<String, Emails> coll = store.putCollection("col01", Emails.class);

putCollection() creates a collection named col01 with the columns defined by the Emails class, or opens the existing collection if one with that name already exists. The variable coll is the handle we will use to refer to the container.

Let us now read data from the spam.csv file and store it in GridDB:

File file1 = new File("spam.csv");
try (Scanner sc = new Scanner(file1)) {
    // skip the header line of the CSV file
    sc.nextLine();

    while (sc.hasNextLine()) {
        String line = sc.nextLine().trim();
        if (line.isEmpty()) {
            continue; // ignore blank lines, e.g. at the end of the file
        }
        // the 58 comma-separated values of one email
        String[] dataList = line.split(",");

        Emails emails = new Emails();
        emails.word1 = dataList[0];
        emails.word2 = dataList[1];
        emails.word3 = dataList[2];
        emails.word4 = dataList[3];
        emails.word5 = dataList[4];
        emails.word6 = dataList[5];
        emails.word7 = dataList[6];
        emails.word8 = dataList[7];
        emails.word9 = dataList[8];
        emails.word10 = dataList[9];
        emails.word11 = dataList[10];
        emails.word12 = dataList[11];
        emails.word13 = dataList[12];
        emails.word14 = dataList[13];
        emails.word15 = dataList[14];
        emails.word16 = dataList[15];
        emails.word17 = dataList[16];
        emails.word18 = dataList[17];
        emails.word19 = dataList[18];
        emails.word20 = dataList[19];
        emails.word21 = dataList[20];
        emails.word22 = dataList[21];
        emails.word23 = dataList[22];
        emails.word24 = dataList[23];
        emails.word25 = dataList[24];
        emails.word26 = dataList[25];
        emails.word27 = dataList[26];
        emails.word28 = dataList[27];
        emails.word29 = dataList[28];
        emails.word30 = dataList[29];
        emails.word31 = dataList[30];
        emails.word32 = dataList[31];
        emails.word33 = dataList[32];
        emails.word34 = dataList[33];
        emails.word35 = dataList[34];
        emails.word36 = dataList[35];
        emails.word37 = dataList[36];
        emails.word38 = dataList[37];
        emails.word39 = dataList[38];
        emails.word40 = dataList[39];
        emails.word41 = dataList[40];
        emails.word42 = dataList[41];
        emails.word43 = dataList[42];
        emails.word44 = dataList[43];
        emails.word45 = dataList[44];
        emails.word46 = dataList[45];
        emails.word47 = dataList[46];
        emails.word48 = dataList[47];
        emails.char1 = dataList[48];
        emails.char2 = dataList[49];
        emails.char3 = dataList[50];
        emails.char4 = dataList[51];
        emails.char5 = dataList[52];
        emails.char6 = dataList[53];
        emails.capitalAvg = dataList[54];
        emails.capitalLongest = dataList[55];
        emails.capitalTotal = dataList[56];
        emails.isSpam = dataList[57];

        // insert the row into the col01 collection
        coll.put(emails);
    }
}

For each line of the CSV file, we create an Emails object, copy the 58 values into its fields, and insert it into the GridDB container.
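Copying 58 values into 58 fields by hand is easy to get wrong, because a skipped or swapped line silently stores a value in the wrong column. One alternative is to list the column names once and copy values by name with reflection. The following is a minimal sketch of that approach. It assumes the fields are named exactly as in the Emails class and follow the same order as the CSV columns. CsvRowMapper, emailColumnNames(), and fill() are hypothetical names.

```java
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: copies CSV values into same-named fields of a row object.
final class CsvRowMapper {

    // Field names in spam.csv column order (assumed to match the Emails class):
    // word1..word48, char1..char6, capitalAvg, capitalLongest, capitalTotal, isSpam.
    static List<String> emailColumnNames() {
        List<String> names = new ArrayList<>();
        for (int i = 1; i <= 48; i++) names.add("word" + i);
        for (int i = 1; i <= 6; i++) names.add("char" + i);
        names.add("capitalAvg");
        names.add("capitalLongest");
        names.add("capitalTotal");
        names.add("isSpam");
        return names;
    }

    // Assigns values[i] to the field called names.get(i) on target.
    static void fill(Object target, List<String> names, String[] values)
            throws ReflectiveOperationException {
        if (values.length != names.size()) {
            throw new IllegalArgumentException(
                    "expected " + names.size() + " values but got " + values.length);
        }
        for (int i = 0; i < values.length; i++) {
            Field field = target.getClass().getDeclaredField(names.get(i));
            field.setAccessible(true); // the Emails fields are package-private
            field.set(target, values[i]);
        }
    }
}
```

Inside the loading loop, the assignments could then become a single call such as CsvRowMapper.fill(emails, names, dataList), with names built once before the loop. The trade-off is that a misspelled field name fails at run time with a NoSuchFieldException instead of at compile time.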

Query the Data from GridDB

Use the following code to pull the data from GridDB:

Query<Emails> query = coll.query("select *");
RowSet<Emails> res = query.fetch(false);

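The TQL statement select * selects all the rows stored in the container. fetch(false) executes the query without locking the rows for update and returns the matching rows as a RowSet named res.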
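Deep Netts, which we use in the next section, loads tabular training data from a CSV file rather than from a GridDB RowSet. One way to bridge the two is to write the fetched rows back out as CSV. The following is a minimal sketch of such an exporter. It reads each field by name through reflection, so the column order of the output is explicit. The class name RowCsvWriter is hypothetical, and the output has no header row.

```java
import java.io.IOException;
import java.io.Writer;
import java.lang.reflect.Field;
import java.util.List;

// Hypothetical helper: writes row objects as comma-separated lines (no header row).
final class RowCsvWriter {
    private final Writer out;
    private final List<String> fieldNames; // column order of the output file

    RowCsvWriter(Writer out, List<String> fieldNames) {
        this.out = out;
        this.fieldNames = fieldNames;
    }

    // Writes the named fields of one row object as a single CSV line.
    void writeRow(Object row) throws IOException, ReflectiveOperationException {
        StringBuilder line = new StringBuilder();
        for (int i = 0; i < fieldNames.size(); i++) {
            Field field = row.getClass().getDeclaredField(fieldNames.get(i));
            field.setAccessible(true);
            if (i > 0) {
                line.append(',');
            }
            line.append(field.get(row)); // String.valueOf semantics; null becomes "null"
        }
        out.write(line.append('\n').toString());
    }
}
```

In the program, you could wrap a FileWriter for a file such as emails_from_griddb.csv (a hypothetical name) in a RowCsvWriter. Pass it the field names in spam.csv column order (word1 to word48, char1 to char6, capitalAvg, capitalLongest, capitalTotal, isSpam), and call writeRow(res.next()) while res.hasNext() returns true.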

Build the Binary Classifier

First, let’s import all the Java libraries that we will need to build the binary classifier:

import deepnetts.data.DataSets;
import deepnetts.eval.Evaluators;
import deepnetts.data.norm.MaxNormalizer;
import javax.visrec.ml.eval.EvaluationMetrics;
import deepnetts.net.layers.activation.ActivationType;
import deepnetts.net.loss.LossType;
import deepnetts.net.FeedForwardNetwork;
import deepnetts.util.DeepNettsException;
import java.io.IOException;
import javax.visrec.ml.classification.BinaryClassifier;
import javax.visrec.ml.ClassificationException;
import javax.visrec.ml.data.DataSet;
import deepnetts.data.MLDataItem;
import javax.visrec.ri.ml.classification.FeedForwardNetBinaryClassifier;

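Deep Netts cannot read a GridDB RowSet directly. Its DataSets.readCsv() method loads a DataSet from a CSV file, so the code below assumes that the rows in res have first been written to a CSV file without a header row. The file name emails_from_griddb.csv is hypothetical. We then build a binary classifier based on a feed-forward neural network:

// emails_from_griddb.csv (hypothetical name) holds the rows of res exported as CSV;
// load 57 input columns and 1 output column; false = the file has no header row
DataSet emailsDataSet = DataSets.readCsv("emails_from_griddb.csv", 57, 1, false);

// create a feed-forward neural network using the builder
FeedForwardNetwork nn = FeedForwardNetwork.builder()
        .addInputLayer(57)                          // one input per feature
        .addFullyConnectedLayer(15)                 // hidden layer with 15 nodes
        .addOutputLayer(1, ActivationType.SIGMOID)  // spam probability between 0 and 1
        .lossFunction(LossType.CROSS_ENTROPY)
        .build();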

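A feed-forward neural network is a machine learning model that can be represented as a directed graph of nodes. The nodes are organized into layers. This network has an input layer with 57 nodes (one per feature), a fully connected hidden layer with 15 nodes, and an output layer with a single node.

Each node applies a function to the inputs it receives from the nodes in the preceding layer and passes the result on to the nodes in the next layer. In the code above, the output layer uses the sigmoid activation function. Sigmoid maps any value into the range between 0 and 1, so the single output can be read as the probability that an email is spam. During training, the cross-entropy loss measures how far these probabilities are from the true labels.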
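The computation that each layer performs can be written out in a few lines. The following minimal sketch runs one forward pass through a 57-15-1 network in plain Java. It is not Deep Netts code. ReLU in the hidden layer is an assumption for illustration, since the default activation of addFullyConnectedLayer() may differ. The class and method names are hypothetical.

```java
// Minimal sketch of one forward pass through a 57-15-1 feed-forward network.
// The hidden-layer activation (ReLU) is an assumption for illustration only.
final class TinyFeedForward {

    interface Activation { float apply(float x); }

    static final Activation RELU = x -> Math.max(0f, x);
    static final Activation SIGMOID = x -> (float) (1.0 / (1.0 + Math.exp(-x)));

    // Fully connected layer: out[j] = f(b[j] + sum over i of w[j][i] * in[i]).
    static float[] dense(float[] in, float[][] w, float[] b, Activation f) {
        float[] out = new float[b.length];
        for (int j = 0; j < b.length; j++) {
            float sum = b[j];
            for (int i = 0; i < in.length; i++) {
                sum += w[j][i] * in[i];
            }
            out[j] = f.apply(sum);
        }
        return out;
    }

    // Input (57 features) -> hidden layer (15 nodes) -> output layer (1 node).
    static float spamProbability(float[] features,
                                 float[][] hiddenW, float[] hiddenB,
                                 float[][] outW, float[] outB) {
        float[] hidden = dense(features, hiddenW, hiddenB, RELU);
        return dense(hidden, outW, outB, SIGMOID)[0];
    }
}
```

With all weights and biases at zero, the output node receives 0 and the sigmoid returns exactly 0.5. Training is the process that adjusts the weights so that this output moves toward 1 for spam and toward 0 for other emails.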

Split the Data into Training and Test Sets

The goal is to train a machine learning model using a subset of our emails dataset and use the other subset of the dataset to test the accuracy of the model. We will split the data into two sets, training and test sets (60% for training and 40% for testing).

The following code will help us to split the data into these two sets:

// split the data
DataSet[] trainTestData = emailsDataSet.split(0.6, 0.4);


// Normalize the data
MaxNormalizer norm = new MaxNormalizer(trainTestData[0]); 
norm.normalize(trainTestData[0]); // normalize the training set
norm.normalize(trainTestData[1]); // normalize the test set
DataSet trainingSet = trainTestData[0];
DataSet testSet = trainTestData[1];

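Note that we normalize the data after splitting it. The MaxNormalizer is created from the training set only and then applied to both sets, so no information from the test set influences the scaling. Normalization matters here because the features are on very different scales. The word and character frequencies are percentages, while the capital-letter run lengths can reach the thousands.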
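The idea behind these two steps can be illustrated in plain Java. The following minimal sketch is not Deep Netts' implementation. It shuffles the rows with a fixed seed and splits them 60/40. It then computes the largest absolute value of each column on the training rows only and divides every value by its column maximum. The rows are assumed to hold only the 57 feature values, with the labels kept separately. The class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Minimal sketch of a shuffled train/test split followed by max normalization,
// with the scaling factors computed on the training rows only.
final class SplitAndNormalize {

    // Shuffles a copy of the rows with a fixed seed and cuts it at trainFraction.
    static List<List<float[]>> split(List<float[]> rows, double trainFraction, long seed) {
        List<float[]> shuffled = new ArrayList<>(rows);
        Collections.shuffle(shuffled, new Random(seed));
        int cut = (int) Math.round(shuffled.size() * trainFraction);
        List<List<float[]>> parts = new ArrayList<>();
        parts.add(shuffled.subList(0, cut));               // training rows
        parts.add(shuffled.subList(cut, shuffled.size())); // test rows
        return parts;
    }

    // Largest absolute value of each column in the training rows.
    static float[] columnMax(List<float[]> trainRows) {
        float[] max = new float[trainRows.get(0).length];
        for (float[] row : trainRows) {
            for (int c = 0; c < row.length; c++) {
                max[c] = Math.max(max[c], Math.abs(row[c]));
            }
        }
        return max;
    }

    // Divides each value in place by its column maximum; all-zero columns are left as-is.
    static void normalize(List<float[]> rows, float[] max) {
        for (float[] row : rows) {
            for (int c = 0; c < row.length; c++) {
                if (max[c] != 0f) {
                    row[c] /= max[c];
                }
            }
        }
    }
}
```

Training values end up between 0 and 1. A test value can end up slightly above 1 when it is larger than every training value in its column. That is expected, since the scaling must not look at the test set.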

Train the Neural Network

Now that the data is prepared, it is time to train the neural network model. The feed-forward neural network is trained with the backpropagation algorithm, which Deep Netts implements in the BackpropagationTrainer class. The network's getTrainer() method returns the trainer attached to the network. We use it to set three training parameters:

- The maximum error: training stops once the training error drops below this value.
- The maximum number of epochs: the number of passes through the training set.
- The learning rate: how large a step each weight update takes.

Use the following code:

// configure the trainer
      nn.getTrainer().setMaxError(0.03f)
                              .setMaxEpochs(10000)
                              .setLearningRate(0.001f);

      // start the training 
      nn.train(trainingSet);

When executed, the training code prints progress similar to the following:

Epoch:1, Time:72ms, TrainError:0.66057104, TrainErrorChange:0.66057104, TrainAccuracy: 0.6289855
Epoch:2, Time:18ms, TrainError:0.6435114, TrainErrorChange:-0.017059624, TrainAccuracy: 0.65072465
Epoch:3, Time:17ms, TrainError:0.6278175, TrainErrorChange:-0.015693903, TrainAccuracy: 0.6786232
Epoch:4, Time:14ms, TrainError:0.60796565, TrainErrorChange:-0.019851863, TrainAccuracy: 0.726087
Epoch:5, Time:15ms, TrainError:0.58832765, TrainErrorChange:-0.019638002, TrainAccuracy: 0.74746376

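The trainer reports the training error and accuracy after each epoch. As training progresses, the error should decrease and the accuracy should increase.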
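To show what the maximum error, maximum epochs, and learning rate control, here is a minimal sketch in plain Java. It is not Deep Netts' BackpropagationTrainer. It trains a single sigmoid neuron (logistic regression) with full-batch gradient descent on the cross-entropy loss. For a sigmoid output with cross-entropy loss, the gradient at the output is prediction minus target, which is the same quantity backpropagation starts from in the full network. Training stops when the mean loss drops below maxError or after maxEpochs epochs, and learningRate scales every update. The class name is hypothetical.

```java
// Minimal sketch: gradient descent for one sigmoid neuron with cross-entropy loss,
// using maxError / maxEpochs / learningRate as stopping and step-size settings.
final class LogisticTrainerSketch {
    final double[] w;
    double b;

    LogisticTrainerSketch(int numInputs) {
        w = new double[numInputs]; // weights start at zero
    }

    double predict(double[] x) {
        double z = b;
        for (int i = 0; i < w.length; i++) z += w[i] * x[i];
        return 1.0 / (1.0 + Math.exp(-z)); // sigmoid
    }

    // Mean binary cross-entropy over the data set (probabilities clamped to avoid log(0)).
    double loss(double[][] xs, int[] ys) {
        double sum = 0;
        for (int n = 0; n < xs.length; n++) {
            double p = Math.min(Math.max(predict(xs[n]), 1e-12), 1 - 1e-12);
            sum -= ys[n] * Math.log(p) + (1 - ys[n]) * Math.log(1 - p);
        }
        return sum / xs.length;
    }

    // Returns the number of epochs actually run.
    int train(double[][] xs, int[] ys, double maxError, int maxEpochs, double learningRate) {
        for (int epoch = 1; epoch <= maxEpochs; epoch++) {
            double[] gradW = new double[w.length];
            double gradB = 0;
            for (int n = 0; n < xs.length; n++) {
                double err = predict(xs[n]) - ys[n]; // dLoss/dz for sigmoid + cross-entropy
                for (int i = 0; i < w.length; i++) gradW[i] += err * xs[n][i];
                gradB += err;
            }
            for (int i = 0; i < w.length; i++) w[i] -= learningRate * gradW[i] / xs.length;
            b -= learningRate * gradB / xs.length;
            if (loss(xs, ys) < maxError) {
                return epoch; // error target reached before maxEpochs
            }
        }
        return maxEpochs;
    }
}
```

A smaller learning rate, such as the 0.001 used above, takes smaller steps. It usually needs more epochs to reach the same error, which is why the maximum number of epochs is set generously.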

Testing the Classifier

We need to be sure that the classifier will give accurate results when classifying new emails as either spam or not spam. Thus, we should test it.

The classifier is tested on the test set, which was not used to train the model. The evaluation counts correct and incorrect classifications and computes other metrics that show how well the classifier performs.

The following code can help us to test the classifier:

EvaluationMetrics metrics = Evaluators.evaluateClassifier(nn, testSet);
System.out.println(metrics);

The code prints metrics such as the accuracy, precision, recall, and F1 score of the classifier.
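All of these metrics come from four counts: true positives, false positives, true negatives, and false negatives. The following minimal sketch is not Deep Netts' Evaluators implementation. It turns predicted spam probabilities into labels with a threshold (0.5 in the usage below is an assumption) and derives accuracy, precision, recall, and F1 score. The class name is hypothetical.

```java
// Minimal sketch: confusion-matrix counts and the derived binary classification metrics.
final class BinaryMetrics {
    int tp, fp, tn, fn;

    static BinaryMetrics of(float[] probabilities, int[] labels, float threshold) {
        BinaryMetrics m = new BinaryMetrics();
        for (int i = 0; i < labels.length; i++) {
            boolean predictedSpam = probabilities[i] >= threshold;
            boolean actualSpam = labels[i] == 1;
            if (predictedSpam && actualSpam) m.tp++;
            else if (predictedSpam) m.fp++;
            else if (actualSpam) m.fn++;
            else m.tn++;
        }
        return m;
    }

    // Fraction of all emails classified correctly.
    double accuracy() { return (double) (tp + tn) / (tp + tn + fp + fn); }

    // Of the emails flagged as spam, the fraction that really are spam.
    double precision() { return tp + fp == 0 ? 0 : (double) tp / (tp + fp); }

    // Of the real spam emails, the fraction that were flagged.
    double recall() { return tp + fn == 0 ? 0 : (double) tp / (tp + fn); }

    // Harmonic mean of precision and recall.
    double f1() {
        double p = precision(), r = recall();
        return p + r == 0 ? 0 : 2 * p * r / (p + r);
    }
}
```

For spam filtering, precision is often the more important of the two. A false positive hides a legitimate email, while a false negative only lets one spam message through.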

Using the Trained Model

We need to use the trained classifier to classify new emails. The following code demonstrates this:

// create a binary classifier using the trained network
BinaryClassifier<float[]> bc = new FeedForwardNetBinaryClassifier(nn);

// get the features of the first test email as an array
float[] testEmail = ((MLDataItem) testSet.get(0)).getInput().getValues();

// get the probability score that the email is spam
Float output = bc.classify(testEmail);
System.out.println("Spam probability: " + output);

We have used the FeedForwardNetBinaryClassifier class to wrap the trained network. BinaryClassifier is a generic interface whose type parameter specifies the input of the classifier, which in our case is an array of email features. We take the first email in the test set, extract its features as a float array, and pass the array to the classify() method. The method returns the probability that the email is spam.

Compile and Execute the Code

Log in as the gsadm user and move your .java file to the bin folder of your GridDB installation, located at the following path:

/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin

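Next, run the following command in your Linux terminal to add the gridstore.jar file to your CLASSPATH. The Deep Netts JAR files that the classifier code imports must be added to the CLASSPATH in the same way.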

export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar

Next, run the following command to compile your .java file:

javac EmailSpamFilter.java

Run the .class file that is generated by running the following command:

java EmailSpamFilter

When the program runs, it prints the training progress and the evaluation metrics, followed by the probability that the selected test email is spam.

If you have any questions about the blog, please create a Stack Overflow post here https://stackoverflow.com/questions/ask?tags=griddb .
Make sure that you use the “griddb” tag so our engineers can quickly reply to your questions.
