Forecasting Timeseries Data Using DJL and GridDB

Introduction

Timeseries data is everywhere. From stock prices and weather patterns to sales figures and sensor data, it plays a crucial role in many aspects of our lives. Being able to forecast future values based on historical timeseries data is invaluable for making informed decisions. In this article, we will explore how to forecast timeseries data using the Deep Java Library (DJL) and GridDB.

What Makes Time Series Data Unique

Time series data is characterized by its chronological order: each data point is associated with a specific timestamp. This data format is prevalent in many domains, such as finance, healthcare, and IoT. To perform effective timeseries forecasting, we need tools and techniques that can capture and understand these temporal patterns.
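
To make this concrete, here is a plain-Java illustration (no library required) of a time series as values keyed by timestamp:

// A plain-Java illustration of a time series: values keyed by timestamp,
// kept in chronological order by a sorted map.
import java.time.LocalDateTime;
import java.util.SortedMap;
import java.util.TreeMap;

public class TimeSeriesExample {
    public static void main(String[] args) {
        SortedMap<LocalDateTime, Double> weeklySales = new TreeMap<>();
        weeklySales.put(LocalDateTime.parse("2011-01-29T00:00"), 42.0);
        weeklySales.put(LocalDateTime.parse("2011-02-05T00:00"), 45.5);
        weeklySales.put(LocalDateTime.parse("2011-02-12T00:00"), 39.0);
        // Iteration yields the observations in time order
        weeklySales.forEach((ts, value) -> System.out.println(ts + " -> " + value));
    }
}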

Introducing Deep Java Library (DJL)

DJL is an open-source deep learning library designed to bring the power of deep learning to Java developers. It provides pre-trained models, tools for training custom models, and seamless integration with various deep learning frameworks like TensorFlow, PyTorch, and MXNet.

Deep Learning for Time Series Forecasting

Deep learning has shown remarkable success in solving complex timeseries forecasting problems. Models like DeepAR (Deep Autoregressive) can capture intricate temporal dependencies and generate accurate predictions. DJL offers an easy way to implement and deploy such models for your timeseries forecasting tasks.

Using DJL for Time Series Forecasting

To get started with DJL for time series forecasting, you need to add the following libraries to your project. Assuming your project is Maven-based, add these to the dependencies section of your POM file:


        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.23.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.timeseries</groupId>
            <artifactId>timeseries</artifactId>
            <version>0.23.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-model-zoo</artifactId>
            <version>0.23.0</version>
        </dependency>
        <!-- ONNX Runtime -->
        <dependency>
            <groupId>ai.djl.onnxruntime</groupId>
            <artifactId>onnxruntime-engine</artifactId>
            <version>0.23.0</version>
        </dependency>

Next, we'll need to set up our environment and understand some key components. Let's take a closer look at the code snippet below:


// Import necessary libraries
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.tabular.utils.Feature;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.*;
import ai.djl.timeseries.Forecast;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;
import ai.djl.timeseries.dataset.TimeFeaturizers;
import ai.djl.timeseries.distribution.DistributionLoss;
import ai.djl.timeseries.distribution.output.DistributionOutput;
import java.time.LocalDateTime;
import java.util.Map;
import java.util.logging.Logger;
// ... other necessary imports ...

public class MonthlyProductionForecast {
    // Constants and configurations
    final static String FREQ = "W";
    final static int PREDICTION_LENGTH = 4;
    final static LocalDateTime START_TIME = LocalDateTime.parse("2011-01-29T00:00");
    final static String MODEL_OUTPUT_DIR = "outputs";

    public static void main(String[] args) throws Exception {
        Logger.getAnonymousLogger().info("Starting...");
        startTraining();
        final Map<String, Float> result = predict();
        for (Map.Entry<String, Float> entry : result.entrySet()) {
            Logger.getAnonymousLogger().info(
                    String.format("metric: %s:\t%.2f", entry.getKey(), entry.getValue()));
        }
    }
}

You can access the full sample project here: GitHub repository
The code above is the entry point of our timeseries forecasting application. It sets up configurations, loads data, and trains a DeepAR model using DJL. Let’s break down how this works:

  • We import necessary DJL libraries and define some constants like the frequency of the timeseries data, prediction length, and start time.

  • The main method starts the training process and connects to a GridDB database to seed it with timeseries data. GridDB is a distributed, highly scalable NoSQL database that can efficiently store timeseries data.

  • We define various methods for prediction, training setup, and data loading, which we will explain in detail later.

Introduction to GridDB

GridDB is a powerful database system designed for storing and managing large volumes of timeseries data. Its high-speed data ingestion and query capabilities make it an excellent choice for timeseries forecasting applications.

Storing Time Series Data in GridDB

First, you need to add the Maven dependencies to be able to use GridDB in your project:


        <dependency>
            <groupId>com.github.griddb</groupId>
            <artifactId>gridstore-jdbc</artifactId>
            <version>5.3.0</version>
        </dependency>
        <dependency>
            <groupId>com.github.griddb</groupId>
            <artifactId>gridstore</artifactId>
            <version>5.3.0</version>
        </dependency>

Second, we need to populate the database with timeseries data. In the seedDatabase method of the GridDBDataset class, we read the data from two CSV files and store it in two separate containers. Here's the code for doing that:


private static void seedDatabase() throws Exception {
    URL trainingData = Forecaster.class.getClassLoader()
            .getResource("data/weekly_sales_train_validation.csv");
    URL validationData = Forecaster.class.getClassLoader()
            .getResource("data/weekly_sales_train_evaluation.csv");
    String[] nextRecord;
    try (GridStore store = GridDBDataset.connectToGridDB();
            CSVReader csvReader = new CSVReader(
                    new InputStreamReader(trainingData.openStream(), StandardCharsets.UTF_8));
            CSVReader csvValidationReader = new CSVReader(
                    new InputStreamReader(validationData.openStream(), StandardCharsets.UTF_8))) {
        // Drop any containers left over from a previous run
        store.dropContainer(TRAINING_COLLECTION_NAME);
        store.dropContainer(VALIDATION_COLLECTION_NAME);

        // Build the training container's schema from the CSV header row
        List<ColumnInfo> columnInfoList = new ArrayList<>();
        nextRecord = csvReader.readNext();
        for (int i = 0; i < nextRecord.length; i++) {
            columnInfoList.add(new ColumnInfo(nextRecord[i], GSType.STRING));
        }

        ContainerInfo containerInfo = new ContainerInfo();
        containerInfo.setColumnInfoList(columnInfoList);
        containerInfo.setName(TRAINING_COLLECTION_NAME);
        containerInfo.setType(ContainerType.COLLECTION);

        Container<Object, Row> container =
                store.putContainer(TRAINING_COLLECTION_NAME, containerInfo, false);

        // Insert each training record as a row
        while ((nextRecord = csvReader.readNext()) != null) {
            Row row = container.createRow();
            for (int i = 0; i < nextRecord.length; i++) {
                row.setString(i, nextRecord[i]);
            }
            container.put(row);
        }

        // Repeat the same steps for the validation container
        nextRecord = csvValidationReader.readNext();
        columnInfoList.clear();
        for (int i = 0; i < nextRecord.length; i++) {
            columnInfoList.add(new ColumnInfo(nextRecord[i], GSType.STRING));
        }

        containerInfo = new ContainerInfo();
        containerInfo.setName(VALIDATION_COLLECTION_NAME);
        containerInfo.setColumnInfoList(columnInfoList);
        containerInfo.setType(ContainerType.COLLECTION);

        container = store.putContainer(VALIDATION_COLLECTION_NAME, containerInfo, false);
        while ((nextRecord = csvValidationReader.readNext()) != null) {
            Row row = container.createRow();
            for (int i = 0; i < nextRecord.length; i++) {
                row.setString(i, nextRecord[i]);
            }
            container.put(row);
        }
    }
}

Integrating DJL with GridDB

DJL and GridDB work seamlessly together. We connect to GridDB to access our timeseries data and use DJL to build, train, and deploy our forecasting model. The GridDBDataset class provides the necessary functionality to interact with GridDB datasets.

We found it necessary to craft a custom implementation of DJL's TimeSeriesDataset by extending the built-in M5Forecast dataset. This turned out to be a clean way to seamlessly integrate DJL with a custom data repository. Here it is:


...
public class GridDBDataset extends M5Forecast {

    ...

    public static GridStore connectToGridDB() throws GSException {
        Properties props = new Properties();
        props.setProperty("notificationMember", "127.0.0.1:10001");
        props.setProperty("clusterName", "defaultCluster");
        props.setProperty("user", "admin");
        props.setProperty("password", "admin");
        return GridStoreFactory.getInstance().getGridStore(props);
    }

    public static class GridDBBuilder extends M5Forecast.Builder {
    ...

        private File fetchDBDataAndSaveCSV(GridStore store) throws GSException, FileNotFoundException {
            File csvOutputFile = new File(this.getContainerName() + ".csv");
            // try-with-resources ensures the store is closed when we are done
            try (GridStore store2 = store) {
                Container<Object, Row> container = store2.getContainer(this.getContainerName());

                Query<Row> query = container.query("select *");
                RowSet<Row> rowSet = query.fetch();

                int columnCount = rowSet.getSchema().getColumnCount();

                List<String> csv = new LinkedList<>();
                StringBuilder builder = new StringBuilder();

                // Load column headers
                ContainerInfo cInfo = rowSet.getSchema();
                for (int i = 0; i < cInfo.getColumnCount(); i++) {
                    ColumnInfo columnInfo = cInfo.getColumnInfo(i);
                    builder.append(columnInfo.getName());
                    appendComma(builder, i, cInfo.getColumnCount());
                }
                csv.add(builder.toString());

                // Load each row
                while (rowSet.hasNext()) {
                    Row row = rowSet.next();
                    builder = new StringBuilder();
                    for (int i = 0; i < columnCount; i++) {
                        builder.append(row.getString(i));
                        appendComma(builder, i, columnCount);
                    }
                    csv.add(builder.toString());
                }
                try (PrintWriter pw = new PrintWriter(csvOutputFile)) {
                    csv.forEach(pw::println);
                }
            }
            return csvOutputFile;
        }

        public GridDBBuilder initData() throws GSException, FileNotFoundException {
            this.csvFile = fetchDBDataAndSaveCSV(this.store);
            return this;
        }

        @Override
        public GridDBDataset build() {
            GridDBDataset gridDBDataset = null;
            try {
                gridDBDataset = new GridDBDataset(this);
            } catch (GSException | FileNotFoundException ex) {
                Logger.getLogger(GridDBDataset.class.getName()).log(Level.SEVERE, null, ex);
            }
            return gridDBDataset;
        }

    }
}

Building an Advanced Time Series Forecasting Model

The core of our forecasting capability lies in the DeepAR model. We create, train, and evaluate the DeepAR model in the startTraining method. DJL's easy-to-use API makes it straightforward to define the model architecture and train it on our timeseries data.


private static void startTraining() throws Exception {

    DistributionOutput distributionOutput = new NegativeBinomialOutput();

    Model model = null;
    Trainer trainer = null;
    NDManager manager = null;
    try {
        manager = NDManager.newBaseManager();
        model = Model.newInstance("deepar");
        DeepARNetwork trainingNetwork = getDeepARModel(distributionOutput, true);
        model.setBlock(trainingNetwork);

        List<TimeSeriesTransform> trainingTransformation =
                trainingNetwork.createTrainingTransformation(manager);

        Dataset trainSet = getDataset(Dataset.Usage.TRAIN,
                trainingNetwork.getContextLength(), trainingTransformation);

        trainer = model.newTrainer(setupTrainingConfig(distributionOutput));
        trainer.setMetrics(new Metrics());

        int historyLength = trainingNetwork.getHistoryLength();
        Shape[] inputShapes = new Shape[9];
        // (N, num_cardinality)
        inputShapes[0] = new Shape(1, 1);
        // (N, num_real) if use_feat_stat_real else (N, 1)
        inputShapes[1] = new Shape(1, 1);
        // (N, history_length, num_time_feat + num_age_feat)
        inputShapes[2] = new Shape(1, historyLength,
                TimeFeature.timeFeaturesFromFreqStr(FREQ).size() + 1);
        inputShapes[3] = new Shape(1, historyLength);
        inputShapes[4] = new Shape(1, historyLength);
        inputShapes[5] = new Shape(1, historyLength);
        inputShapes[6] = new Shape(1, 1,
                TimeFeature.timeFeaturesFromFreqStr(FREQ).size() + 1);
        inputShapes[7] = new Shape(1, 1);
        inputShapes[8] = new Shape(1, 1);
        trainer.initialize(inputShapes);

        int epoch = 10;
        EasyTrain.fit(trainer, epoch, trainSet, null);
    } finally {
        // Release native resources in reverse order of creation
        if (trainer != null) {
            trainer.close();
        }
        if (model != null) {
            model.close();
        }
        if (manager != null) {
            manager.close();
        }
    }
}

Now, let's break down each step of the startTraining method:

Step 1: We define the distribution output for the model. In this case, it's set to NegativeBinomialOutput. The distribution output specifies how the model should generate forecasts.
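
The distribution head is the main lever for matching the model to your data. A brief illustration (StudentTOutput is another head shipped in DJL's ai.djl.timeseries.distribution.output package):

// NegativeBinomialOutput models non-negative count data, e.g. weekly units sold.
DistributionOutput distributionOutput = new NegativeBinomialOutput();
// For real-valued series, a Student's t head is a common alternative:
// DistributionOutput distributionOutput = new StudentTOutput();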

Step 2: We create the DeepAR training network using the getDeepARModel method. This network is responsible for defining the architecture of the DeepAR model. It's important to note that we pass true to indicate that this is a training network.
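
The getDeepARModel helper isn't reproduced above; here is a hedged sketch of what it could look like, based on DJL's DeepARNetwork builder API. The cardinality list (the number of categories per static categorical feature) is a placeholder that must match your dataset.

// A sketch of getDeepARModel using DJL's DeepARNetwork builder; the
// cardinality value is an illustrative placeholder.
private static DeepARNetwork getDeepARModel(DistributionOutput distributionOutput,
        boolean training) {
    List<Integer> cardinality = new ArrayList<>();
    cardinality.add(1); // hypothetical: one static categorical feature

    DeepARNetwork.Builder builder = DeepARNetwork.builder()
            .setCardinality(cardinality)
            .setFreq(FREQ)
            .setPredictionLength(PREDICTION_LENGTH)
            .optDistrOutput(distributionOutput)
            .optUseFeatStaticCat(true);
    // Training and prediction use slightly different network graphs
    return training ? builder.buildTrainingNetwork() : builder.buildPredictionNetwork();
}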

Step 3: We define training transformations for the dataset. These transformations will be applied to the input data to prepare it for training. They may include data normalization, feature engineering, and more.

Step 4: We prepare the training dataset using the getDataset method. This dataset will be used to train the DeepAR model. It includes historical data and target values for training.
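
The getDataset helper isn't shown above either. Below is a plausible sketch: the gridDBBuilder() factory and the optStore/optContainerName setters are assumptions about the custom builder, while setTransformation, setContextLength, setSampling, addFeature, and prepare follow DJL's M5Forecast API; the w_1..w_277 column names and the batch size of 32 are illustrative.

// A hedged sketch of getDataset wiring the GridDBBuilder to DJL's dataset API.
private static Dataset getDataset(Dataset.Usage usage, int contextLength,
        List<TimeSeriesTransform> transformation) throws Exception {
    GridDBDataset.GridDBBuilder builder = GridDBDataset.gridDBBuilder() // hypothetical factory
            .optStore(GridDBDataset.connectToGridDB())
            .optContainerName(usage == Dataset.Usage.TRAIN
                    ? TRAINING_COLLECTION_NAME : VALIDATION_COLLECTION_NAME)
            .initData(); // pulls the container's rows into a local CSV
    builder.optUsage(usage)
            .setTransformation(transformation)
            .setContextLength(contextLength)
            .setSampling(32, usage == Dataset.Usage.TRAIN);
    // Register the weekly sales columns as the forecasting target
    for (int i = 1; i <= 277; i++) {
        builder.addFeature("w_" + i, FieldName.TARGET);
    }
    GridDBDataset dataset = builder.build();
    dataset.prepare(new ProgressBar());
    return dataset;
}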

Step 5: We create and configure the trainer for training the model. The setupTrainingConfig method sets up the training configuration, including loss functions, evaluators, and training listeners.
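
A sketch of what setupTrainingConfig might look like, based on DJL's training API; the single-device setup, Xavier initializer, and logging directory are illustrative choices:

// Configures the loss, evaluator, initializer, and listeners for training.
private static DefaultTrainingConfig setupTrainingConfig(DistributionOutput distributionOutput) {
    return new DefaultTrainingConfig(new DistributionLoss("Loss", distributionOutput))
            .addEvaluator(new Rmsse(distributionOutput))     // tracks RMSSE during training
            .optDevices(Engine.getInstance().getDevices(1))  // train on a single device
            .optInitializer(new XavierInitializer(), Parameter.Type.WEIGHT)
            .addTrainingListeners(TrainingListener.Defaults.logging(MODEL_OUTPUT_DIR));
}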

Step 6: We initialize the trainer with input shapes. This step ensures that the trainer knows the expected input shapes for the model. The inputShapes array contains the shapes for different input components of the model.

Step 7: Finally, we start training the model using the EasyTrain.fit method. We specify the number of training epochs, the training dataset (trainSet), and other optional parameters. The trainer will optimize the model's parameters to minimize the defined loss function and improve its performance on the training data.

Overall, the startTraining method sets up and trains a DeepAR model for timeseries forecasting by configuring the model, preparing the dataset, and initializing the trainer. This combination of steps ensures that the model is trained effectively to make accurate predictions based on historical timeseries data.

Prediction

After training, we can use the predict method to make forecasts based on the trained model. This method calculates various metrics, such as RMSSE (Root Mean Squared Scaled Error), MSE (Mean Squared Error), and quantile losses, to evaluate the model's performance.
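
The predict method itself isn't reproduced here, but the core inference flow might look like the following hedged sketch. The translator argument keys follow DJL's DeepARTranslator conventions; the saved-model path and the input series are assumptions (requires java.nio.file.Paths and java.util.concurrent.ConcurrentHashMap imports).

// A minimal inference sketch: load the trained weights into a prediction
// network, wrap it in a Predictor, and generate a probabilistic Forecast.
Map<String, Object> arguments = new ConcurrentHashMap<>();
arguments.put("prediction_length", PREDICTION_LENGTH);
arguments.put("freq", FREQ);
arguments.put("use_feat_dynamic_real", false);
arguments.put("use_feat_static_cat", true);
arguments.put("use_feat_static_real", false);
DeepARTranslator translator = DeepARTranslator.builder(arguments).build();

try (Model model = Model.newInstance("deepar");
        NDManager manager = NDManager.newBaseManager()) {
    model.setBlock(getDeepARModel(new NegativeBinomialOutput(), false)); // prediction network
    model.load(Paths.get(MODEL_OUTPUT_DIR)); // assumes the trained model was saved here
    try (Predictor<TimeSeriesData, Forecast> predictor = model.newPredictor(translator)) {
        TimeSeriesData input = new TimeSeriesData(10);
        input.setStartTime(START_TIME);
        input.setField(FieldName.TARGET, manager.create(new float[] {/* observed history */}));
        Forecast forecast = predictor.predict(input);
        NDArray mean = forecast.mean();            // point forecast for the next steps
        NDArray median = forecast.quantile("0.5"); // median of the sampled paths
    }
}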

Conclusion

In this article, we've explored how to forecast time series data using DJL and GridDB. We introduced the key concepts of timeseries data, DJL, and GridDB, and provided a detailed explanation of the code involved in building and training a DeepAR model for timeseries forecasting. By combining the power of deep learning with the efficiency of GridDB, you can unlock valuable insights from your timeseries data and make informed decisions for your business or research. DJL's ease of use and flexibility make it a valuable tool for any Java developer looking to tackle time series forecasting challenges.

In conclusion, the synergy between DJL and GridDB empowers you to harness the potential of time series data, providing accurate forecasts that can drive better decision-making in various domains. With the knowledge gained from this article, you are well-equipped to embark on your journey of time series forecasting using state-of-the-art deep learning techniques and robust database solutions.

The fusion of DJL and GridDB opens up new possibilities in the world of time series forecasting. As you delve deeper into this field, you'll discover the power of data-driven insights and how they can revolutionize industries ranging from finance to healthcare. Keep exploring, keep learning, and keep forecasting the future with confidence.

If you have any questions about the blog, please create a Stack Overflow post at https://stackoverflow.com/questions/ask?tags=griddb. Make sure that you use the "griddb" tag so our engineers can quickly reply to your questions.
