Training a Linear Regression model in Kotlin with TensorFlow Java API
My story began a few days ago, when I realized there were no examples on how to train a Linear Regression model on TensorFlow using Java API.
Why is this so important to me? Who would use TensorFlow to find the best weight and bias for a linear problem? Who would try to do it in Java? It may sound as crazy as eating noodles with a hammer, doesn’t it?
Let me explain: I’m a professional ML/DL framework designer and my main area of expertise is Java and other JVM languages like Kotlin.
I need a good Java API for TF for many reasons:
- it could be used in integration with Apache Ignite ML or Apache Spark ML
- it could be used in many JVM back-ends as a mature Deep Learning platform (PyTorch announced an experimental Java API but it’s in very early stages)
- it is difficult to write an equivalent library from scratch; it’s easier to wrap an existing one
Unfortunately, the TensorFlow developer community is focused almost entirely on the Python API and doesn’t pay enough attention to the JVM-to-C bridge.
I hope that my example and tutorials help people get involved in Java API development and that it becomes more active as a result.
Currently, there is no dedicated Kotlin API for TensorFlow, but Java examples can be converted to Kotlin with IntelliJ IDEA’s Java-to-Kotlin converter, so you can use them from a Kotlin back-end.
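As a small illustration (my own sketch, not from the official examples), here is how a Java placeholder declaration from the 1.15 API maps to Kotlin after conversion:
// Java:
//   Placeholder<Float> x = tf.placeholder(Float.class, Placeholder.shape(Shape.scalar()));
// Kotlin after IntelliJ’s converter (Float.class becomes Float::class.javaObjectType):
val x = tf.placeholder(Float::class.javaObjectType, Placeholder.shape(Shape.scalar()))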
I was able to use the Java API for TensorFlow 1.15* to train this simple model.
Let’s start with data preparation.
For this example, I’ll take ten x values in the xValues array; the Y data is generated approximately by the formula 10 * x + 2 + noise.
// Prepare the X data.
val xValues = floatArrayOf(1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f)

// Prepare the Y data: y = 10 * x + 2 + noise.
val yValues = FloatArray(10)
val noise = Random(42) // kotlin.random.Random, seeded once so each element gets different noise
for ((i, x) in xValues.withIndex()) {
    yValues[i] = 10 * x + 2 + noise.nextDouble(-0.1, 0.1).toFloat()
}
In TensorFlow you need to describe a static Graph of operations to perform forward and backward propagation. Java API 1.15 has more than 800 operations for describing complex models, including convolutional networks, LSTMs, and others.
But I will describe the simplest possible graph, with around 20 operations, to run the calculations.
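One note before the snippets: they assume that a Graph and the tf operations handle already exist. A minimal setup (not shown in the original listing) looks like this:
import org.tensorflow.Graph
import org.tensorflow.op.Ops

val graph = Graph()          // the static graph all operations are added to
val tf = Ops.create(graph)   // the entry point for building operations on it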
Let’s start by building the foundation for the Graph:
val X = tf.placeholder(Float::class.javaObjectType,
    Placeholder.shape(Shape.scalar()))
val Y = tf.placeholder(Float::class.javaObjectType,
    Placeholder.shape(Shape.scalar()))
These two placeholders are the building blocks that receive the x and y values from the xValues and yValues arrays.
To bind the x and y values in a linear model, we need to define two variables: weight and bias. Variables are another popular building block in TensorFlow graph construction.
val weight: Variable<Float> = tf.variable(Shape.scalar(),
    Float::class.javaObjectType)
val bias: Variable<Float> = tf.variable(Shape.scalar(),
    Float::class.javaObjectType)
All variables should be initialized, in this case with constants:
val weightInit = tf.assign(weight, tf.constant(1f))
val biasInit = tf.assign(bias, tf.constant(1f))
Here’s the formula for Linear Regression that binds weight, bias, x and y together:
y = weight * x + bias (or y = wx + b)
val mul = tf.math.mul(X, weight)
val yPredicted = tf.math.add(mul, bias)
So, yPredicted is the actual model (or function), which can easily be reused for predictions later.
Every optimization task (and ML/DL training is an optimization task too) needs a function to optimize. In this case, I’m using MSE as the loss function: loss = (yPredicted − y)^2 / (2n).
// n is the number of observations (10 in this example).
val sum = tf.math.pow(tf.math.sub(yPredicted, Y), tf.constant(2f))
val mse = tf.math.div(sum, tf.constant(2f * n))
At the heart of deep learning lies the gradient descent algorithm and the gradient calculation.
I won’t go into the details of backward propagation computation here, but I should mention that the Java API for TF doesn’t support optimizers, so you can’t choose Adam or Adadelta directly as you would in the Python API.
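For reference, tf.train.applyGradientDescent (used below) performs the plain gradient descent update, where alpha is the learning rate:
weight ← weight − alpha * ∂mse/∂weight
bias ← bias − alpha * ∂mse/∂bias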
val gradients = tf.gradients(mse, listOf(weight, bias))
val alpha = tf.constant(0.2f)
val weightGradientDescent =
tf.train.applyGradientDescent(weight, alpha,
gradients.dy<Float>(0))
val biasGradientDescent =
tf.train.applyGradientDescent(bias, alpha,
gradients.dy<Float>(1))
And now it’s ready for training!
First of all, I need to initialize the variables inside the session:
Session(graph).use { session ->
session.runner()
.addTarget(weightInit)
.addTarget(biasInit)
.run()
Second, let’s create a pair of Tensors for each <x, y> pair, feed them into the Graph, and change our weight and bias (the neural network parameters) by a small delta:
for ((cnt, x) in xValues.withIndex()) {
    val y = yValues[cnt]
    Tensor.create(x).use { xTensor ->
        Tensor.create(y).use { yTensor ->
            session.runner()
                .addTarget(weightGradientDescent)
                .addTarget(biasGradientDescent)
                .feed(X.asOutput(), xTensor)
                .feed(Y.asOutput(), yTensor)
                .run()
            println("$x $y")
        }
    }
}
The real data is passed in via the .feed() method and bound to the placeholders defined earlier.
There are no epochs or batches in this simple training process: we have only 10 observations, and the parameters can move toward a minimum example by example during gradient descent.
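With a larger dataset you would add epochs (and batching). A minimal sketch of an epoch loop around the same training step, with a hypothetical numEpochs value:
val numEpochs = 100 // hypothetical epoch count, not in the original code
repeat(numEpochs) {
    for ((cnt, x) in xValues.withIndex()) {
        val y = yValues[cnt]
        Tensor.create(x).use { xTensor ->
            Tensor.create(y).use { yTensor ->
                session.runner()
                    .addTarget(weightGradientDescent)
                    .addTarget(biasGradientDescent)
                    .feed(X.asOutput(), xTensor)
                    .feed(Y.asOutput(), yTensor)
                    .run()
            }
        }
    }
}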
So, the model is trained. The weight and bias are found and are kept inside the session. Let’s get them from TensorFlow.
// Extract the weight value
val weightValue = session.runner()
.fetch("Variable")
.run()[0].floatValue()
println("Weight is $weightValue")
// Extract the bias value
val biasValue = session.runner()
.fetch("Variable_1")
.run()[0].floatValue()
println("Bias is $biasValue")
Unfortunately, you have to know the actual names of the variables inside the TensorFlow session if you want to access them. The names Variable and Variable_1 are generated by default.
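To avoid relying on generated names, the Java API lets you name operations explicitly with tf.withName(...) at graph-construction time; a sketch of declaring the variables that way:
// Declare the variables with explicit names so they can be fetched by name later.
val weight: Variable<Float> =
    tf.withName("weight").variable(Shape.scalar(), Float::class.javaObjectType)
val bias: Variable<Float> =
    tf.withName("bias").variable(Shape.scalar(), Float::class.javaObjectType)

// Later: session.runner().fetch("weight").run()[0].floatValue()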
Last but not least, every ML model should be able to make predictions.
Let’s predict regression value for the known x = 10f:
val x = 10f
var predictedY = 0f
Tensor.create(x).use { xTensor ->
    Tensor.create(NO_MEANING_VALUE).use { yTensor ->
        predictedY = session.runner()
            .feed(X.asOutput(), xTensor)
            .feed(Y.asOutput(), yTensor)
            .fetch(yPredicted)
            .run()[0].floatValue()
    }
}
println("Predicted value: $predictedY")
You can fill the Y tensor with any value when making the prediction, since yPredicted does not depend on Y.
After running this code, I got the following output:
Weight is 9.866241
Bias is 3.364536
Predicted value: 102.02695
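As a quick sanity check, plugging the learned weight and bias into y = wx + b by hand gives the same number:
// Verifying the fetched prediction with plain Kotlin arithmetic:
val check = 9.866241f * 10f + 3.364536f
println("Check: $check") // prints a value ≈ 102.02695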
You can find this and other deep learning code examples with Kotlin on GitHub.
Graph building and training are not easy with the TensorFlow Java API 1.15.
I hope this tutorial will be helpful for doing deep learning in your production systems.
* The Java API for TF 2.x has not been released yet, but SIG JVM is working on it.