ANNs - Build and train model using MNIST, Kotlin and DeepLearning4J

March 24, 2019

Introduction

Using Kotlin and DeepLearning4J, we’ll discover today how easy it is to build and train our own neural network and visualize the results of its training.

For this purpose, we’ll make use of the MNIST dataset, which is basically the hello-world for image recognition using neural networks.

At the end of this tutorial, we’ll have built a project able to build and train a model on MNIST.

Access the GitHub repository here.

Creation of the project

We’ll use IntelliJ and Maven to generate a Kotlin project from a default template (also named archetype) named kotlin-archetype-jvm.

To do this, open IntelliJ and perform the following actions

  1. Go to File > New > Project…
  2. On the left ribbon, select Maven
  3. On the right container, check Create from archetype and select kotlin-archetype-jvm in the list of archetypes available

IntelliJ’s new project creation pop-up

  1. Click on Next
  2. Enter a group id for your Maven project (e.g com.company) and an artifact id (e.g kotlin-dl4j-medium)
  3. Click on Next twice
  4. Give a name to your project (it can be the same as your artifact id to be consistent)
  5. Click on Finish

You should now have a new project with the correct structure to begin working.

Adding the needed dependencies to our project

We now have an empty project containing the dependencies needed to run a basic Kotlin application. We’ll now add to our pom.xml file the dependencies needed to build and train our model, and to set up visualization.

Let’s first begin with the basic ones: the core library for DeepLearning4J, which gives us all the objects needed for building and training the network, and ND4J, which is also required.

<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-beta3</version>
</dependency>
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>1.0.0-beta3</version>
</dependency>

To be able to work with the MNIST dataset, we’ll need the following dependency to be added to our pom.xml file.

<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-datasets</artifactId>
    <version>1.0.0-beta3</version>
</dependency>

Finally, to switch on the visualization, we’ll need the following dependency.

<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-ui_2.10</artifactId>
    <version>1.0.0-beta3</version>
</dependency>

Now run the following to install the dependencies.

$ mvn install

You should see output like the following in the console, meaning everything went well.

[INFO] ---------------------------------------------
[INFO] BUILD SUCCESS
[INFO] ---------------------------------------------
[INFO] Total time: 11.790 s
[INFO] Finished at: 2019-03-24T16:25:51+01:00
[INFO] Final Memory: 52M/619M
[INFO] ---------------------------------------------

Process finished with exit code 0

Write the code

If the archetype did its job right, you should have a file named Hello.kt inside the source folder of your project.

Rename it to something more self-explanatory (e.g. MNISTTrainer).

Inside it, you’ll find a minimal piece of code that prints Hello World! to the console. Remove its contents to keep only the following part.

fun main(args: Array<String>) {

}

We’ll start by adding the piece of code needed for visualization. It will basically set up a local web application accessible from http://localhost:9000.

fun main(args: Array<String>) {
    val uiServer = UIServer.getInstance()
    val statsStorage = InMemoryStatsStorage()
    uiServer.attach(statsStorage)
}

We now have our setup for launching a web app. If you click on Run, you should be able to access the web application at http://localhost:9000 and see the following screen.

Now that we have our visualization set up, we’ll continue by defining our constants. From now on, I’ll replace the existing code with a comment, to focus only on the newly added code.

fun main(args: Array<String>) {
    // Set up of visualization
    val numRows = 28
    val numColumns = 28
    val pixelCount = numRows * numColumns   
    val outputNum = 10
    val batchSize = 128
    val rngSeed = (0..100).random()
    val numEpochs = 5
}
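As a quick sanity check on these constants, here is a small sketch of what they imply. It assumes MNIST’s standard 60,000 training images, a figure that doesn’t appear in the code above.

```kotlin
// Illustration only: assumes MNIST's standard 60,000 training images.
val trainingImages = 60_000
val batchSize = 128

// Each 28x28 image is flattened into a 784-element input vector.
val pixelCount = 28 * 28

// fit() walks the iterator batch by batch; the last batch may be smaller.
val batchesPerEpoch = (trainingImages + batchSize - 1) / batchSize // ceiling division
```

So each epoch runs 469 weight-update steps over 784-dimensional inputs.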

We can now instantiate our iterators. These objects let us iterate through our datasets. We’ll build two of them: one for training and one for testing. For this, we’ll use the MnistDataSetIterator class provided by DeepLearning4J.

The Boolean parameter of the constructor defines whether the set is used for training or not.

fun main(args: Array<String>) {
    // Set up visualization
    // Define constants
    val mnistTrain = MnistDataSetIterator(batchSize, true, rngSeed)
    val mnistTest = MnistDataSetIterator(batchSize, false, rngSeed)
}

Let’s now build our model.

For that, we’ll use the most basic setup possible, to avoid unnecessary complexity.

The setup includes the following details:

  1. Two layers, a dense one and an output one.
  2. ReLU as activation function for the dense layer and SoftMax for the output layer.
  3. Adam as optimizer.

fun main(args: Array<String>) {
    // Set up visualization
    // Define constants
    // Split dataset in Train and Test
    val multiLayerConfiguration = NeuralNetConfiguration.Builder()
        .seed(rngSeed.toLong())
        .updater(Adam())
        .list()
        .layer(
        DenseLayer.Builder()
            .nIn(pixelCount)
            .nOut(1000)
            .activation(Activation.RELU)
            .build()
        )
        .layer(
        OutputLayer.Builder() // create output layer
            .nOut(outputNum)
            .activation(Activation.SOFTMAX)
            .build()
        )
        .build()

    val model = MultiLayerNetwork(multiLayerConfiguration)
    model.init()
    model.setListeners(StatsListener(statsStorage))
}

You can see that we also add a listener to our model. This listener links the visualization to the model: it feeds the training statistics into the statsStorage we created in the visualization setup above.
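As a rough sanity check on this architecture, we can count the parameters it creates. This is a back-of-the-envelope sketch; the sizes 784, 1000 and 10 come from pixelCount, the hidden layer’s nOut and outputNum above.

```kotlin
// Each fully connected layer holds nIn * nOut weights plus nOut biases.
fun denseParams(nIn: Int, nOut: Int) = nIn * nOut + nOut

val hiddenParams = denseParams(784, 1_000) // dense layer: 785,000 parameters
val outputParams = denseParams(1_000, 10)  // output layer: 10,010 parameters
val totalParams = hiddenParams + outputParams
```

Roughly 795,000 parameters: tiny by modern standards, but more than enough for MNIST.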

To train it, nothing is easier than looping over the number of epochs and calling the fit method of our model, passing it the training dataset.

fun main(args: Array<String>) {
    // Set up visualization
    // Define constants
    // Split dataset in Train and Test
    // Configure and init the model
    for (i in 0 until numEpochs) {
        model.fit(mnistTrain)
    }
}

And to test (evaluate) it, use the following piece of code.

fun main(args: Array<String>) {
    // Set up visualization
    // Define constants
    // Split dataset in Train and Test
    // Configure and init the model
    // Train the model

    val eval = Evaluation(outputNum) //create an evaluation object with 10 possible classes
    while (mnistTest.hasNext()) {
        val next = mnistTest.next()
        val output = model.output(next.features) // get the network's prediction
        eval.eval(next.labels, output) // check the prediction against the true class
    }

    println(eval.stats())
}

You can see I added a print at the end; after the evaluation, it will show you something like the following, revealing how well your model performed and where its weaknesses lie.

========================Evaluation Metrics========================
 # of classes:    10
 Accuracy:        0,9772
 Precision:       0,9771
 Recall:          0,9771
 F1 Score:        0,9770
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)

=========================Confusion Matrix=========================
    0    1    2    3    4    5    6    7    8    9
---------------------------------------------------
  970    0    1    0    1    2    3    1    2    0 | 0 = 0
    0 1126    1    1    0    1    3    0    3    0 | 1 = 1
    3    1 1010    0    1    0    3    5    8    1 | 2 = 2
    1    2    9  964    0   15    0    4    7    8 | 3 = 3
    2    1    3    0  958    0    3    1    0   14 | 4 = 4
    3    0    0    2    2  871    8    0    5    1 | 5 = 5
    3    3    0    1    3    2  946    0    0    0 | 6 = 6
    1    6    9    1    2    0    0 1000    2    7 | 7 = 7
    8    0    1    0    6    3    7    3  940    6 | 8 = 8
    1    3    0    2    5    2    2    5    2  987 | 9 = 9

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
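The accuracy figure can be recomputed from the confusion matrix: correct predictions sit on the diagonal. Here is a small sketch using the numbers from the sample run above; your run will produce different values.

```kotlin
// Confusion matrix copied from the sample run above (rows = actual class).
val confusion = listOf(
    listOf(970, 0, 1, 0, 1, 2, 3, 1, 2, 0),
    listOf(0, 1126, 1, 1, 0, 1, 3, 0, 3, 0),
    listOf(3, 1, 1010, 0, 1, 0, 3, 5, 8, 1),
    listOf(1, 2, 9, 964, 0, 15, 0, 4, 7, 8),
    listOf(2, 1, 3, 0, 958, 0, 3, 1, 0, 14),
    listOf(3, 0, 0, 2, 2, 871, 8, 0, 5, 1),
    listOf(3, 3, 0, 1, 3, 2, 946, 0, 0, 0),
    listOf(1, 6, 9, 1, 2, 0, 0, 1000, 2, 7),
    listOf(8, 0, 1, 0, 6, 3, 7, 3, 940, 6),
    listOf(1, 3, 0, 2, 5, 2, 2, 5, 2, 987),
)

val total = confusion.sumOf { it.sum() }                       // 10,000 test images
val correct = confusion.indices.sumOf { i -> confusion[i][i] } // diagonal entries
val accuracy = correct.toDouble() / total                      // 0.9772
```

The off-diagonal cells point at the weaknesses: for instance, row 3, column 5 shows that fifteen 3s were misread as 5s.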

Click on Run now (the play button at the top of IntelliJ’s UI).

If you open http://localhost:9000, you will also be able to see your model’s training and testing statistics.

Screenshot of the DeepLearning4J UI
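One last note: the code fragments above omit their import statements. Under the 1.0.0-beta3 package layout they should look roughly like the following; package locations have moved between DL4J versions, so treat these as an assumption and let the IDE’s auto-import confirm them.

```kotlin
// Assumed package locations for DL4J 1.0.0-beta3; verify against your version.
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.eval.Evaluation
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.layers.DenseLayer
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.ui.api.UIServer
import org.deeplearning4j.ui.stats.StatsListener
import org.deeplearning4j.ui.storage.InMemoryStatsStorage
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.learning.config.Adam
```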

Conclusion

It is really that easy to integrate the DeepLearning4J library into a Kotlin app and generate a deep learning model.

Next time, I’ll try to go further by integrating the model into an existing app and making use of it.

GitHub repository : https://github.com/YassinHajaj/dl4j-kotlin