Federated Deep Learning using Java on the Client and in the Cloud

Deep Learning algorithms are a subclass of general machine learning algorithms. One of the core ideas of deep learning is that it bears some similarity to how the human brain works. Just as layers of neurons in the brain process information, deep learning software allows developers to create a network containing multiple layers of neurons that process information.

Especially when large amounts of data are available, Deep Learning can provide high-quality results. For example, deep learning software can be used to classify images and detect objects in those images. Before a deep learning algorithm can tell what objects are in an image, it has to be trained with lots of data. Sometimes a large amount of high-quality data is available for training the network, but in many real-world situations this is not the case. Real-world data is often obtained via consumer devices (e.g. pictures taken on a mobile device) and cannot easily be shared or transferred, due to privacy restrictions or regulations.

In this scenario, where there is a large amount of data but it can’t be sent to a server, Federated Deep Learning comes to the rescue.

Using Deep Learning software to make predictions

We start with the straightforward case where, based on an image, we want the Neural Network (or model) to tell us what is shown in the image. In order to do so, the image is converted into an array of numbers, where each element corresponds to the grayscale value of one pixel. Note that more complex models are possible, taking into account colors etc., but we want to start with a simple example. The result of this conversion, the array of numbers, is sent through the Neural Network, and when the network is well trained, it will recognize what is in the image.

Figure 1
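As a minimal sketch of the conversion step described above, the following plain Java code turns an image into an array of grayscale values. The class and method names are made up for illustration, and averaging the red, green and blue channels is just one simple assumption about how to compute a gray value; it is not taken from deeplearning4j.

```java
import java.awt.image.BufferedImage;

final class GrayscaleInput {

    // Flattens an image into one grayscale value per pixel, scaled to the range [0, 1].
    static double[] toGrayscaleArray(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        double[] pixels = new double[width * height];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int rgb = image.getRGB(x, y);
                int r = (rgb >> 16) & 0xFF;
                int g = (rgb >> 8) & 0xFF;
                int b = rgb & 0xFF;
                // Plain average of the three channels (an assumption; weighted formulas are also common).
                pixels[y * width + x] = (r + g + b) / (3.0 * 255.0);
            }
        }
        return pixels;
    }

    public static void main(String[] args) {
        // A tiny 2x1 test image: one white pixel and one black pixel.
        BufferedImage image = new BufferedImage(2, 1, BufferedImage.TYPE_INT_RGB);
        image.setRGB(0, 0, 0xFFFFFF);
        image.setRGB(1, 0, 0x000000);
        double[] input = toGrayscaleArray(image);
        System.out.println(input.length + " values, first = " + input[0] + ", second = " + input[1]);
    }
}
```

The resulting array is what would be handed to the input layer of the network.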

Under the hood, the Neural Network consists of a number of layers: an input layer containing the original data (the array of numbers), an output layer containing the results (in this case the most likely label, “bike”) and one or more hidden layers. Data propagates through the network, using network parameters like weights, biases and activation functions, and produces the output data. The following diagram, taken from https://deeplearning4j.org, illustrates the process:

Figure 2
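To make the propagation through the layers more concrete, here is a small sketch in plain Java of a forward pass through one hidden layer and one output layer. All names and all weight and bias values are made up for illustration; a real library such as deeplearning4j does this with optimized linear algebra rather than nested loops.

```java
import java.util.Arrays;

final class ForwardPass {

    // Computes weights * input + biases for one fully connected layer,
    // optionally followed by a ReLU activation.
    static double[] dense(double[] input, double[][] weights, double[] biases, boolean relu) {
        double[] out = new double[biases.length];
        for (int j = 0; j < out.length; j++) {
            double sum = biases[j];
            for (int i = 0; i < input.length; i++) {
                sum += weights[j][i] * input[i];
            }
            out[j] = relu ? Math.max(0.0, sum) : sum;
        }
        return out;
    }

    // Turns raw output scores into probabilities that sum to 1.
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double total = 0.0;
        double[] probs = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max); // subtracting max avoids overflow
            total += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= total;
        }
        return probs;
    }

    public static void main(String[] args) {
        double[] input = {0.2, 0.8};                                  // input layer
        double[][] hiddenW = {{0.5, -0.3}, {0.1, 0.9}, {-0.7, 0.4}};  // 2 inputs -> 3 hidden nodes
        double[] hiddenB = {0.0, 0.1, 0.0};
        double[][] outW = {{1.0, -1.0, 0.5}, {-0.5, 1.0, 0.2}};       // 3 hidden nodes -> 2 labels
        double[] outB = {0.0, 0.0};

        double[] hidden = dense(input, hiddenW, hiddenB, true);
        double[] probabilities = softmax(dense(hidden, outW, outB, false));
        System.out.println(Arrays.toString(probabilities));
    }
}
```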

A course in deep learning is out of scope for this article; the interested reader is referred to
https://deeplearning4j.org/docs/latest/deeplearning4j-beginners.

Before a Neural Network can be used, it should be trained. This is done by applying lots of data to the network and telling the network when it is wrong. Based on this feedback, the deep learning algorithm modifies the parameters of the network (weights and biases) in such a way that the network becomes more likely to make a good prediction. In many cases, labeled training data is used to train the neural network. When an image of a mountain is supplied to the Neural Network and the network classifies it as a “bike” while the associated label is “mountain”, the deep learning software adjusts the parameters of the neural network, typically using techniques such as backpropagation and gradient descent.

Figure 3

Under the hood, this comes down to a new set of parameters, as shown below:

Figure 4
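The idea that a correction leads to a new set of parameters can be sketched with a single training step in plain Java. This is a deliberately simplified assumption: one layer, plain gradient descent on the negative log likelihood, and made-up input values and names. It is not what deeplearning4j does internally, which supports multiple layers and more advanced updaters.

```java
final class TrainingStep {

    // Returns the predicted probability for each label (softmax over weighted sums).
    static double[] predict(double[] x, double[][] w, double[] b) {
        double[] scores = new double[b.length];
        double max = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < b.length; j++) {
            scores[j] = b[j];
            for (int i = 0; i < x.length; i++) {
                scores[j] += w[j][i] * x[i];
            }
            max = Math.max(max, scores[j]);
        }
        double total = 0.0;
        for (int j = 0; j < scores.length; j++) {
            scores[j] = Math.exp(scores[j] - max);
            total += scores[j];
        }
        for (int j = 0; j < scores.length; j++) {
            scores[j] /= total;
        }
        return scores;
    }

    // One gradient descent step for a single labeled example.
    // Modifies w and b in place and returns the error score measured before the update.
    static double trainStep(double[] x, int label, double[][] w, double[] b, double learningRate) {
        double[] probs = predict(x, w, b);
        for (int j = 0; j < b.length; j++) {
            // Gradient of the negative log likelihood with respect to the score of label j.
            double gradient = probs[j] - (j == label ? 1.0 : 0.0);
            for (int i = 0; i < x.length; i++) {
                w[j][i] -= learningRate * gradient * x[i];
            }
            b[j] -= learningRate * gradient;
        }
        return -Math.log(probs[label]);
    }

    public static void main(String[] args) {
        double[] image = {0.9, 0.1, 0.4}; // made-up input values
        int mountain = 1;                 // index of the correct label; index 0 stands for "bike"
        double[][] w = new double[2][3];
        w[0][0] = 1.0;                    // start with a network that leans towards "bike"
        double[] b = new double[2];
        for (int step = 0; step < 5; step++) {
            System.out.println("error score: " + trainStep(image, mountain, w, b, 0.5));
        }
    }
}
```

Each call nudges the weights and biases towards the correct label, so the printed error score shrinks from one step to the next.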

Internally, Deep Learning software typically involves high-performance linear algebra. While predictions are relatively easy and straightforward, training is more complex, and requires more computing power. Also, training typically requires lots of high-quality data.

There are many software libraries providing Deep Learning APIs, in different languages. One of the leading Java libraries is deeplearning4j (https://deeplearning4j.org), created and maintained by SkyMind (https://skymind.ai). While the top-level API for deeplearning4j is pure Java, the implementations behind the API leverage native code, including GPU-specific optimisations.

Training a neural network on the server

In a first setup, we will have a client that sends raw data to a server containing the Neural Network and asks it for a result. Before this can be done, the server needs a trained network. We can train a neural network using the deeplearning4j APIs, for example:



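import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
               .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
               .updater(new Nesterovs(learningRate, 0.9))
               .list()
               // Hidden layer: numInputs values in, numHiddenNodes values out.
               .layer(0, new DenseLayer.Builder()
                       .nIn(numInputs)
                       .nOut(numHiddenNodes)
                       .weightInit(WeightInit.XAVIER)
                       .activation(Activation.RELU)
                       .build())
               // Output layer: one probability per possible label.
               .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                       .nIn(numHiddenNodes)
                       .nOut(numOutputs)
                       .weightInit(WeightInit.XAVIER)
                       .activation(Activation.SOFTMAX)
                       .build())
               .backprop(true)
               .build();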


 

In this code snippet, we create the configuration for the neural network. Again, the details about the deeplearning4j library are out of scope for this article, and the reader is referred to https://deeplearning4j.org for more information. The configuration contains 2 layers (the input data is often not considered to be a layer as it is not configurable). The first layer is a hidden layer, and the second layer is the output layer. The output of the first layer is the input of the second layer.

When the network is configured, it can be trained. The following code snippet will feed the network with labeled data.



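// Create and initialise the network from the configuration above.
MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init();
// Read the labeled training data: label in column 0, two possible classes.
RecordReader rrTrain = new CSVRecordReader();
File trainsrc = new File("linear_data_train.csv");
rrTrain.initialize(new FileSplit(trainsrc));
DataSetIterator iterTrain = new RecordReaderDataSetIterator(rrTrain, batchSize, 0, 2);
// Train the network on the labeled data.
network.fit(iterTrain);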

In this case, the training data is in a CSV file, and it contains both the input data and the expected result. The network.fit(iterTrain) call will cause the network to be trained. A number of iterations will be executed. In each iteration, a forward pass is applied and the input data is used to make a prediction. Based on the predicted result and the actual result (contained in the CSV file), the network is modified. An error score tells how “wrong” the network is, and the goal is that with each iteration, the network is less wrong than before.

Once the network is trained, we can make requests to predict results based on new data. A conventional way of doing this is to host the deep learning software with the neural network on a server, and have a client send raw image data to it via a REST interface. This approach is shown in the picture below.

Figure 5

Typically, the neural network responds with an array containing possible labels and their probabilities (e.g. there is a 90% chance this image contains a bike and a 5% chance it contains a car). Additional processing can be done on this response before it is sent back to the client. In general, the REST interface might call something like this:



String predicted = doSomeProcessing(network.output(data));
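What doSomeProcessing does is up to the application. One plausible sketch, in plain Java, is to pick the label with the highest probability. The class name, the method name and the label list are hypothetical, and the sketch assumes that the network output has already been copied into a plain double array, whereas deeplearning4j returns its own array type.

```java
final class PredictionProcessing {

    // Hypothetical post-processing step: returns the label with the highest probability.
    static String mostLikelyLabel(double[] probabilities, String[] labels) {
        if (labels.length == 0 || probabilities.length != labels.length) {
            throw new IllegalArgumentException("need exactly one probability per label");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return labels[best];
    }

    public static void main(String[] args) {
        String[] labels = {"bike", "car", "mountain"}; // made-up label order
        double[] probabilities = {0.90, 0.05, 0.05};   // made-up network output
        System.out.println(mostLikelyLabel(probabilities, labels));
    }
}
```

A server could also return the full list of probabilities and leave the choice to the client.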
		

An advantage of this approach is that the Neural Network can be enhanced. In case the network returns a wrong result (e.g. “car”), the client might send a correction to the server (“bike”). Based on this information, the neural network can be retrained and become better, as shown in the picture below:

Figure 6

There are a number of drawbacks to this approach as well. Since client-server communication is needed, responses are not real-time. If the client needs an immediate response, the time to set up an HTTP call alone may already be too much. Also, the client device might not always be connected to the Internet, while the application assumes that the neural network can always return predictions. And, increasingly important, this approach requires sending raw image data to a server. Due to privacy restrictions and regulations, this might not always be allowed.

Predictions on the client

A number of these drawbacks can be removed by doing predictions on the client. If we run the same code on the client as on the server, the client can use a simple API to query the neural network.

Figure 7

The Java code for making predictions on the client is exactly the same as the code for making predictions on the server.

If you want to run deeplearning4j Java code on mobile clients, you can use the Gluon Mobile framework, which allows you to create Java apps that work on iOS and Android devices. An explanation of how to do this is outside the scope of this article, and the interested reader is referred to http://docs.gluonhq.com/getting-started/. A sample project showing deep learning on mobile devices using Java code only is available here:
https://github.com/gluonhq/gluon-samples/tree/master/deeplearning-linearclassifier.

The deeplearning4j software contains import and export functionality for neural networks. Hence, it is possible to train the network on a server, export it to a file, send that file to the client, and import the file as a neural network on the client. In general, predictions are not very computation-intensive. Moreover, deeplearning4j leverages native implementations available on mobile devices, so performance on the device can be very good.
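The export and import round trip can be illustrated with a plain Java sketch that writes a flat array of parameters to a file and reads it back. This is only a stand-in to show the idea: the class name, method names and file layout are made up, and deeplearning4j ships its own import and export functionality with its own file format.

```java
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

final class ParameterFile {

    // Server side: write the number of parameters, followed by each value.
    static void exportTo(Path file, double[] parameters) throws IOException {
        try (DataOutputStream out = new DataOutputStream(
                new BufferedOutputStream(Files.newOutputStream(file)))) {
            out.writeInt(parameters.length);
            for (double p : parameters) {
                out.writeDouble(p);
            }
        }
    }

    // Client side: read the values back in the same order.
    static double[] importFrom(Path file) throws IOException {
        try (DataInputStream in = new DataInputStream(
                new BufferedInputStream(Files.newInputStream(file)))) {
            double[] parameters = new double[in.readInt()];
            for (int i = 0; i < parameters.length; i++) {
                parameters[i] = in.readDouble();
            }
            return parameters;
        }
    }

    public static void main(String[] args) throws IOException {
        Path file = Files.createTempFile("network", ".bin");
        exportTo(file, new double[] {0.5, -1.25, 3.0}); // made-up parameter values
        System.out.println(Arrays.toString(importFrom(file)));
        Files.delete(file);
    }
}
```

Note that the file contains only numbers describing the network, not the data the network was trained on.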

This new approach allows the client to get a result in real-time, and it also works fine when the device is not connected to the Internet. Also, the data (e.g. a picture taken with a mobile phone) stays on the client, so there are no privacy concerns. In order to get a meaningful result based on the data, there is no need to send this data from the device to a server.

There is a drawback to this approach though: the model no longer improves. In the previous setup, the client could correct the neural network. This feedback loop is required in order to make neural networks really useful in real-world applications. Deep learning works better if more data is available, and if the neural network is constantly updated.

Fortunately, we can achieve this goal without giving up the other benefits of local predictions.

Training on the client

Since the training code on the server is written in Java, it works on the clients as well. When the neural network predicts a wrong result, and the user wants to correct it, the network can be retrained.

Figure 8

Since training is more resource-intensive than simply making predictions, it is recommended to do this during low-activity time, e.g. at night, and preferably when the device is charged or charging.

This approach now has a number of benefits, including the learning capability of the neural network. In order to become really useful though, it would be interesting to have a large number of clients jointly train a neural network without sending raw data. This is done in the next setup.

Federated Learning

When a neural network on a specific client device is retrained, that client can send the new parameters of its neural network to a server. The parameters of the neural network do not contain the raw image data that is used to train the neural network. When many devices send their new parameters to a server, this server can combine them and create a new “best” neural network, which is occasionally sent back to the clients. There are a number of algorithms for doing this, and again, the details are out of scope for this article.

The general idea though is shown below:

Figure 9

Each client retrains the Neural Network with its own data, and the modifications in the Neural Network are sent to a server. This approach allows for an enhancement of the Neural Network while respecting users' privacy. Moreover, the computing power of a large number of client devices is used to improve the Neural Network.
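One simple way for the server to combine the parameters is a weighted average, often referred to as federated averaging. The sketch below, in plain Java, shows that single idea and nothing more. The class names are hypothetical, the parameters are assumed to be flattened into an array of the same length on every client, and weighting by the number of local training examples is an assumption rather than something prescribed by this article.

```java
import java.util.Arrays;
import java.util.List;

final class FederatedAveraging {

    // Parameters reported by one client, plus the number of local examples it trained on.
    static final class ClientUpdate {
        final double[] parameters;
        final int sampleCount;

        ClientUpdate(double[] parameters, int sampleCount) {
            this.parameters = parameters;
            this.sampleCount = sampleCount;
        }
    }

    // Combines client parameters into one set, weighting each client by its sample count.
    static double[] combine(List<ClientUpdate> updates) {
        if (updates.isEmpty()) {
            throw new IllegalArgumentException("no client updates");
        }
        int size = updates.get(0).parameters.length;
        long totalSamples = 0;
        for (ClientUpdate update : updates) {
            if (update.parameters.length != size) {
                throw new IllegalArgumentException("parameter count mismatch");
            }
            totalSamples += update.sampleCount;
        }
        if (totalSamples <= 0) {
            throw new IllegalArgumentException("no training samples reported");
        }
        double[] combined = new double[size];
        for (ClientUpdate update : updates) {
            double weight = (double) update.sampleCount / totalSamples;
            for (int i = 0; i < size; i++) {
                combined[i] += weight * update.parameters[i];
            }
        }
        return combined;
    }

    public static void main(String[] args) {
        // Two clients with made-up parameters; the second one trained on more examples.
        List<ClientUpdate> updates = Arrays.asList(
                new ClientUpdate(new double[] {1.0, 2.0}, 1),
                new ClientUpdate(new double[] {3.0, 4.0}, 3));
        System.out.println(Arrays.toString(combine(updates)));
    }
}
```

Only the parameter arrays and the sample counts reach the server in this sketch; the images themselves never leave the devices.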

Training a neural network is not trivial. Thanks to the deeplearning4j APIs, and the availability of Java on client devices, Java provides a great platform for enabling Federated Learning.