==== Model configuration ====
After data processing, three fundamental aspects were treated in order to have a FL system that converges efficiently to meaningful solutions. Indeed, a correct model configuration plays a crucial role in FL as it encompasses the selection of the appropriate model architecture, optimization algorithm, and criterion.
===== Model architecture =====
The model architecture plays a crucial role in defining how the data flows
through the network, the number and type of layers, the connections between
neurons, and the activation functions applied at different stages.
Since the main purpose of this work is to compare two frameworks, the
choice of architecture was made so as to maximise as much as possible the
metrics chosen at the end of the training but limiting the resource demands
on the part of the model and thus also the time required to complete the
process.
====== Cloud environment ======
The model architecture chosen for the cloud part
prioritises simplicity and compatibility with the resource-constrained CPUs
of embedded devices and virtual machines. This decision was motivated by
the need for a lightweight and efficient model that can be easily deployed
and executed on various devices participating in the FL system.
For that reason, the architecture chosen for the cloud part of this work
is a model taken from the "Training a classifier" tutorial available on the
official PyTorch website [22] (Listing 4.2). This selection fits perfectly with
the request for a lightweight model that could still deliver satisfactory results.
The model of choice is a sequential one since they are often suitable and
widely used for classification problems [61]. It is a type of NN architecture
composed by a plain stack of layers where each layer has exactly one input
tensor and one output tensor. As shown in the figure, the model consists of
several layers, starting with two Convolutional layers (Conv2d) that process
the input data followed by MaxPooling layers (MaxPool2d) that reduce the
spatial dimensions. Afterwards, there are three fully connected layers (Lin-
ear) responsible for the final classification. The model has a total of 62,006
parameters, which are the trainable weights and biases within the layers.
These parameters are optimised during the training process to achieve accu-
rate predictions. The model architecture is suitable for tasks such as image
classification and demonstrates a balance between depth and complexity,
allowing for efficient training and satisfactory performance.
====== Local environment ======
A more complex model architecture was chosen for
the local environment to leverage the usage of an NVidia GPU for model
training. This decision was driven by the aim to harness the computational
power of the GPU and expedite the training process, ultimately leading to
improved model performance and more efficient training. This led to better
metrics results compared to the cloud counterpart.
In this case, the architecture that was chosen is indeed the ResNet-18
[5], as shown in Listing 4.3. ResNet-18 is a highly effective DL model for
image classification tasks, known for its ability to handle deeper architectures
without sacrificing performance. It performs exceptionally well with the
CIFAR-10 dataset and is computationally efficient.
In the context of this work, the choice was to use ResNet-18 in the
’weights=DEFAULT’ mode since it offers several advantages. The ’DE-
FAULT’ mode simplifies integration into the FL system and allows for faster
convergence and better generalisation through transfer learning with pre-
trained weights. In fact, this means that the model was initially trained on
a large dataset, typically for a different task (e.g., ImageNet classification),
and its learned weights and parameters were saved. Instead of training the
model from scratch on a new task, you start with these pre-trained weights
as an initialisation.
The layers can be summarised as follows:
1. Convolutional Layers (Conv2d): The initial layer of ResNet-18 consists
of a 2D convolution operation that convolves the input image with a set
of learnable filters. These filters help detect various low-level features,
such as edges and corners, in the input image.
2. Batch Normalization Layers (BatchNorm2d): After each convolutional
layer, batch normalization is applied, which normalizes the output of
the previous layer to improve training stability and accelerate conver-
gence.
3. Rectified Linear Unit Activation (ReLU): Following the batch nor-
malization, a non-linear activation function called ReLU is applied
element-wise to introduce non-linearity into the model and allow it
to learn complex features.
4. Max Pooling Layers (MaxPool2d): After the initial convolutional block,
max pooling layers are used to downsample the spatial dimensions of
the feature maps, reducing computational complexity while retaining
important information.
5. Basic Blocks: The ResNet-18 model utilizes a series of eight basic
blocks, of which for simplicity only four are visible visible in the code
above, each consisting of multiple convolutional and batch normaliza-
tion layers. The basic blocks are designed to mitigate the vanishing
gradient problem and allow the model to be deeper without perfor-
mance degradation.
6. Adaptive Average Pooling (AdaptiveAvgPool2d): The adaptive average
pooling layer aggregates the spatial dimensions of the feature maps into
a fixed size, ensuring the model can handle input images of varying sizes
and aspect ratios.
7. Fully Connected Layers (Linear): Towards the end of the model, adap-
tive average pooling is applied to convert the spatial dimensions of the
feature maps into a fixed size. Subsequently, fully connected layers are
used to perform classification based on the learned features.
The Resnet-18 model used was also modified with custom fully connected
layers (Linear) in order to accommodate a different output classification task.
In the original ResNet-18, the model was designed for image classification
with 1,000 output classes. However, in this modified version, the number of
output classes was reduced to just 10 to align with the specific classification
problem at hand. By reducing the number of output classes, the model’s
architecture becomes more tailored to the target classification task, which,
in turn, reduces the number of trainable parameters by about half million
parameters.
===== Optimizer =====
The optimiser chosen, presented in the introduction paragraph, is the SGD
since it is a popular optimiser for classification problems due to its simplicity
and effectiveness. It was tuned using some hyperparameters: momentum and
learning rate.
Momentum is a technique used to accelerate the convergence of the op-
timisation process and improve its stability. It addresses the issue of slow
convergence and oscillations in the loss function by introducing a "velocity"
term that helps the optimiser navigate the optimisation landscape more ef-
ficiently. The value of 0.9 that was chosen, meaning that the optimiser gives
more weight to the past accumulated gradients, leading to smoother updates.
Learning rate is a hyperparameter that determines the step size at which
the model updates its weights during the optimisation process. It controls
how much the model adjusts its internal parameters in response to the error
calculated during training. In this case the chosen learning rate of 0.001
strikes a balance between making smaller, more precise steps towards the
optimal solution and avoiding overshooting or oscillations during the opti-
misation process.
===== Criterion =====
Cross-entropy is commonly used in classification problems because it quan-
tifies the difference between the predicted probabilities and the actual target
labels, providing a measure of how well the model is performing in classifying
the input data.
This can be expressed using this formula:
where (y) is the target probability, (p) is the predicted probability, and
(m) is the number of classes. So that is how “wrong” or “far away” the
prediction is from the true distribution.
In the context of CIFAR-10, where there are ten classes (e.g., airplanes,
cars, birds, etc.), the Cross-Entropy loss compares the predicted class proba-
bilities with the true one-hot encoded labels for each input sample. It applies
the logarithm to the probabilities and then sums up the negative log like-
lihoods across all classes. The objective is to minimize this loss function
during the training process, which effectively encourages the model to as-
sign high probabilities to the correct class labels and low probabilities to the
incorrect ones.
One of the reasons why Cross-Entropy Loss is considered suitable for
CIFAR-10 and classification tasks, in general, is its ability to handle multi-
class scenarios efficiently. By transforming the model’s output into probabil-
ities through the softmax activation, it inherently captures the relationships
between different classes, allowing for a more expressive representation of
class likelihoods.
==== Client-side settings ====
On the client side, three important tasks including training, validation and
testing are being performed. Each task comes to be executed by every client
participating in the FL infrastructure. At the end of a cycle, tasks come to be
blocked temporarily so that the results accumulated up to that point are sent
to the server, which will take care of aggregating them. Once aggregated,
each client will start again with the tasks assigned to it from the data formed
by the server.
The model is trained for 3 epochs in the case of the cloud environment
and for 4 epochs in the case of the local environment. The total steps per
epoch in both cases is for each task of 1250 steps.
==== Aggregation algorithm ====
In this FL scenario, the Federated Averaging (FedAvg) [10] algorithm was
employed as the aggregation method. FedAvg is a fundamental and widely
adopted algorithm used to aggregate model updates from multiple clients
(or participants) in a FL setting.
The primary objective of FedAvg is to allow collaborative model training
while preserving data privacy. After local training, clients communicate their
model updates (gradients) to the server, where these updates are aggregated
to create a global model. The global model is then sent back to the clients
that, use it as the starting point for the next round of training. This iterative
process continues until the global model converges to a satisfactory solution.
The Listing 4.4 describes how FedAvg works in a FL infrastructure:
As reader can see, the FedAvg algorithm works by averaging the model
updates from individual clients, weighted by the proportion of data samples
each client holds. This weighted average ensures that clients with larger
datasets have a more significant influence on the global model, while main-
taining fairness for clients with smaller datasets. By iteratively aggregating
the updates and distributing the global model back to clients, FedAvg en-
ables collaborative learning without sharing raw data.
==== Metrics ====
In order to make a good comparison, three of the most common and essential
metrics were chosen to evaluate model performance and effectiveness.
The chosen metrics are the following:
• Loss: The loss function quantifies the dissimilarity between the pre-
dicted output of the model and the actual ground truth labels in the
training data. It provides a measure of how well the model is perform-
ing during training. The goal is to minimize the loss function, as a
lower loss indicates that the model is better aligned with the training
data.
• Accuracy: Accuracy is a fundamental metric used to assess the model’s
overall performance. It represents the proportion of correctly predicted
samples to the total number of samples in the dataset. A higher ac-
curacy indicates that the model is making accurate predictions, while
a lower accuracy suggests that the model might need further improve-
ments. Calculating the accuracy of individual clients in a FL classifi-
cation problem is important to assess the performance of each client’s
local model. This helps in understanding how well each client is adapt-
ing to its local data distribution and making accurate predictions.
The formula for the accuracy can be simply expressed as follows:
It ranges between 0 and 1, where 1 indicates perfect accuracy, meaning
correctly predicted sample, and 0 indicates it will always fail in the
prediction.
F1-score: The F1-score is a metric that combines both precision and
recall to provide a balanced evaluation of the model’s performance,
especially when dealing with imbalanced datasets. Precision measures
the ratio of correctly predicted positive samples to all predicted positive
samples, while recall measures the ratio of correctly predicted positive
samples to all actual positive samples. The F1-score is the harmonic
mean of precision and recall, providing a single metric that considers
both aspects.
==== Server-side settings ====
After the choice of the metrics to evaluate, the last thing to decide were the
server settings. In fact, two important parameters are missing in this regard:
the number of rounds and the number of clients that would have participated
in the FL infrastructure.
A round represents a communication cycle between clients and the cen-
tral server in the FL training process. During each round, participating
clients perform local training using their available local data. Subsequently,
the updated model weights trained locally are sent to the central server or
coordination node. Here, the weights are centrally aggregated to obtain an
updated global model, which represents the combined knowledge of all par-
ticipating clients. At this point, the round is concluded and the aggregate
model is sent back to the clients, who will use this updated model to perform
a new round.
In this case, the number of rounds chosen for the cloud part is equal to
4 meanwhile for the local part there is a total of 10 rounds.
On the server-side, the second parameter to be chosen is the number of
clients that will participate in the various rounds of FL. As also seen in the
subsection 4.1.1, in this case, the number of cloud-side clients is 2 while on
the local-side is equal to 4.
=== Results ===
|}
== Deep investigation of Applying NVFlare to a real-word case ==
TBD