
1 Introduction

In many real-world scenarios, data is distributed over organizations or devices and is difficult to centralize. Due to legal reasons, data might have to remain and be processed where it is generated, and in many cases may not be allowed to be transferred [10]. Furthermore, due to communication limitations it can be practically impossible to send data to a central point of processing. In many applications of Machine Learning (ML) these challenges are becoming increasingly important to address. For example, sensors, cars, radio base stations and mobile devices are capable of generating more relevant training data than can be practically communicated to the cloud [8] and datasets in healthcare and industry cannot legally be moved between hospitals or countries of origin.

Fig. 1. Our approach adjusts to non-Independent and Identically Distributed (IID) data distributions by adaptively training a Mixture of Experts (MoE) for clients that share similar data distributions.

Federated Learning (FL) [1, 27] shows promise for leveraging data that cannot easily be centralized. It has the potential to utilize the compute and storage resources of clients to scale towards large, decentralized datasets while enhancing privacy. However, current approaches fall short when data is heterogeneous and non-Independent and Identically Distributed (non-IID), with stark differences between clients and groups of clients. Personalization of collectively learned models will therefore often be critical in practice to adapt to differences between regions, organizations and individuals and to achieve the required performance [11, 18]. This is the problem we address in this chapter.

Our approach adjusts to non-IID data distributions by adaptively training a Mixture of Experts (MoE) for clients that share similar data distributions. We explore a wide spectrum of data distribution settings, ranging from the same distribution for all clients all the way to a different distribution for each client. Our aim is an end-to-end framework that performs comparably to or better than vanilla FL and is robust in all of these settings.

In order to achieve personalization, the authors of [11] introduce a method for training cluster models using FL. We show that their solution does not perform well in our settings, where only one or a few of the cluster models converge. To solve this, inspired by the Multi-Armed Bandit (MAB) field, we employ an efficient and effective way of balancing exploration and exploitation of these cluster models. As proposed by the authors of [24, 29], we add a local model and use an MoE that learns to weigh, and make use of, all of the available models to produce a better personalized inference, see Fig. 1.

In summary, our main contributions are:

  1. We devise an FL algorithm which improves upon [11] by balancing exploration and exploitation to produce better adapted cluster models, see Sect. 3.1;

  2. We use said cluster models as expert models in an MoE to improve performance, described in Sect. 3.1;

  3. An extensive analysis (Footnote 1) of our approach with respect to different non-IID aspects that also considers the distribution of client performance, see Sect. 4.5.

2 Background

2.1 Problem Formulation

Consider a distributed and decentralized ML setting with clients \({k} \in \left\{ 1, 2, \ldots ,\right. \left. {K}\right\} \). Each client k has access to a local data partition \({P^k}\) that never leaves the client, where \({n_k =\vert {P^k}\vert }\) is the number of local data samples. In this chapter we consider a multi-class classification problem with \({{n} = \sum _{{k}=1}^{{K}} n_k}\) data samples \({\boldsymbol{x}_i}\), where on each client the local samples are indexed \({{i} \in \left\{ 1, 2, \ldots , {n_k}\right\} }\) and the output class label \({y_i}\) is in a finite set. We further divide each client partition \({P^k}\) into local training and test sets. We are interested in performance on the local test set in a non-IID setting, see Sect. 2.2.
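To make the setup concrete, the following is a minimal sketch of the local train/test split described above; the 80/20 ratio and all names are our own illustrative assumptions, not specified in the chapter.

```python
import numpy as np

def split_local_partition(P_k, test_fraction=0.2, seed=0):
    """Split one client's local partition P^k into a local training set
    and a local test set; P^k never leaves the client.

    P_k: list of (x_i, y_i) pairs, with n_k = len(P_k). The 80/20 split
    is an illustrative assumption, not specified in the chapter.
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(P_k))
    n_test = int(test_fraction * len(P_k))
    test = [P_k[i] for i in idx[:n_test]]
    train = [P_k[i] for i in idx[n_test:]]
    return train, test
```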

2.2 Regimes of Non-IID Data

In any decentralized setting it is common to have non-IID data that can be of non-identical client distributions [14, 18], and which can be characterized as:

  • Feature distribution skew (covariate shift). The feature distributions vary between clients: the marginal distribution \({\mathcal {P}}\left( \boldsymbol{x}\right) \) varies, but \({\mathcal {P}}\left( {y} \,\vert \,\boldsymbol{x}\right) \) is shared;

  • Label distribution skew (prior probability shift, or class imbalance). The distribution of class labels differs between clients, so that \({\mathcal {P}}\left( {y}\right) \) varies but \({\mathcal {P}}\left( \boldsymbol{x} \,\vert \,{y} \right) \) is shared;

  • Same label, different features. The conditional distribution \({\mathcal {P}}\left( \boldsymbol{x} \,\vert \,{y}\right) \) varies between clients but \({\mathcal {P}}\left( {y}\right) \) is shared;

  • Same features, different label (concept shift). The conditional distribution \({\mathcal {P}}\left( {y} \,\vert \,\boldsymbol{x}\right) \) varies between clients, but \({\mathcal {P}}\left( \boldsymbol{x}\right) \) is shared;

  • Quantity skew (unbalancedness). Clients have different amounts of data.

Furthermore, the data independence between clients and between data samples within a client can also be violated.

2.3 Federated Learning

In a centralized ML solution, potentially privacy-sensitive data is collected at a central location. One way of improving privacy is to use a collaborative ML algorithm such as Federated Averaging (FedAvg) [27]. In FedAvg, training of a global model \({f_g}(\boldsymbol{x}, {\boldsymbol{w}_{g}})\) is distributed, decentralized and synchronous. A parameter server coordinates training on many clients over several communication rounds until convergence.

In communication round t, the parameter server selects a fraction C out of K clients as the set \({S_t}\). Each selected client \({{k}\in {S_t}}\) will train locally on \({n_k}\) data samples \({({\boldsymbol{x}_i}, {y_i}), {i} \in {P^k}}\), for E epochs before an update is sent to the parameter server. The parameter server performs aggregation of all received updates and updates the global model parameters \({\boldsymbol{w}_{g}}\). Finally, the new global model parameters are distributed to all clients.
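To make the training loop concrete, here is a minimal sketch of one FedAvg communication round under the description above. This is a hypothetical PyTorch setting: the client objects with `train_loader` and `n_k` attributes and all function names are our own scaffolding, not from the chapter.

```python
import copy
import random
import torch
import torch.nn.functional as F

def fedavg_round(global_model, clients, C, E, lr):
    """One synchronous FedAvg round.

    clients: objects with a train_loader and sample count n_k (our own
    scaffolding). Assumes models output log-probabilities, matching the
    negative log-likelihood loss used in the chapter.
    """
    K = len(clients)
    S_t = random.sample(clients, max(1, int(C * K)))   # select fraction C
    updates, counts = [], []
    for client in S_t:
        local = copy.deepcopy(global_model)            # start from w_g
        opt = torch.optim.SGD(local.parameters(), lr=lr)
        for _ in range(E):                             # E local epochs
            for x, y in client.train_loader:
                opt.zero_grad()
                F.nll_loss(local(x), y).backward()
                opt.step()
        updates.append(local.state_dict())
        counts.append(client.n_k)
    # Model averaging: w_g <- sum_k (n_k / n) w_k; float parameters assumed
    n = sum(counts)
    averaged = {key: sum(c / n * u[key] for c, u in zip(counts, updates))
                for key in updates[0]}
    global_model.load_state_dict(averaged)
```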

We can now define our objective as

$$\begin{aligned} \min _{\boldsymbol{w}_{g}} \frac{1}{{n}} \sum _{{k}=1}^{{K}} \sum _{{i} \in {P^k}} {l}\left( {\boldsymbol{x}_i}, {y_i}, {\boldsymbol{w}_{g}}\right) , \end{aligned}$$
(1)

where \({{l}\left( {\boldsymbol{x}_i}, {y_i}, {\boldsymbol{w}_{g}}\right) }\) is the loss for the label \({y_i}\) and the prediction \({{\hat{y}_g} = {f_g}\left( {\boldsymbol{x}_i}, {\boldsymbol{w}_{g}}\right) }\). In other words, we aim to minimize the average loss of the global model over all clients in the population.

2.4 Iterative Federated Clustering

In many real distributed use-cases, data is naturally non-IID and clients form clusters of similar clients. A possible improvement over FedAvg is to introduce cluster models that map to these clusters, but the problem of identifying clients that belong to these clusters remains. We aim to find clusters, subsets of the population of clients, that benefit more from training together within the subset, as opposed to training with the entire population.

Using Iterative Federated Clustering Algorithm (IFCA) [11] we set the expected largest number of clusters to be J and initialize one cluster model with weights \({\boldsymbol{w}^j_g}\) per cluster \({j \in \left\{ 1,2,\ldots ,{J}\right\} }\). At communication round t each selected client k performs a cluster identity estimation, where it selects the cluster model \({\hat{j}}^k\) that has the lowest estimated loss on the local training set. This is similar to [26].

The cluster model parameters \({\boldsymbol{w}^j_g}\) at time \(t+1\) are then updated using only updates from the clients that selected the jth cluster model, so that (using model averaging [11, 27])

$$\begin{aligned} {n_j} \leftarrow&\sum \nolimits _{k \in \left\{ {S_t} \,\,\vert \,\, {\hat{j}}^k = j\right\} } {n_k},\end{aligned}$$
(2)
$$\begin{aligned} {\boldsymbol{w}^j_g(t+1)} \leftarrow&\sum \nolimits _{k \in \left\{ {S_t} \,\,\vert \,\, {\hat{j}}^k = j\right\} } \frac{{n_k}}{{n_j}}{\boldsymbol{w}^k(t+1)}. \end{aligned}$$
(3)
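Below is a sketch of the two IFCA steps described above: cluster identity estimation and the per-cluster aggregation in (2) and (3). Again this is a hypothetical PyTorch setting with our own helper names and data structures.

```python
import torch
import torch.nn.functional as F

def estimate_cluster_identity(client, cluster_models):
    """Return the index j of the cluster model with the lowest loss
    on the client's local training set."""
    losses = []
    for model in cluster_models:
        model.eval()
        with torch.no_grad():
            loss = sum(F.nll_loss(model(x), y, reduction="sum").item()
                       for x, y in client.train_loader)
        losses.append(loss)
    return min(range(len(cluster_models)), key=losses.__getitem__)

def aggregate_cluster_model(j, selected, assignments, updates):
    """Model averaging per (2)-(3) over the clients in S_t that selected
    cluster j. Assumes at least one client picked cluster j."""
    members = [k for k in selected if assignments[k] == j]
    n_j = sum(updates[k]["n_k"] for k in members)                    # (2)
    return {key: sum(updates[k]["n_k"] / n_j * updates[k]["weights"][key]
                     for k in members)
            for key in updates[members[0]]["weights"]}               # (3)
```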

2.5 Federated Learning Using a Mixture of Experts

In order to construct a personalized model for each client, [24] first add a local expert model \({{f^k_l}(\boldsymbol{x}, {\boldsymbol{w}^k_l})}\) that is trained only on local data. Recall the global model \({{f_g}(\boldsymbol{x}, {\boldsymbol{w}_{g}})}\) from Sect. 2.3. The authors of [24] then learn to weigh the local expert model and the global model using a gating function from MoE [12, 15, 29]. The gating function takes the same input \(\boldsymbol{x}\) and outputs a weight for each of the expert models. It uses a Softmax in the output layer so that these weights sum to 1. We define \({{f_h^k}\left( \boldsymbol{x}, {\boldsymbol{w}^k_h}\right) }\) as the gating function for client k. The same model architectures are used for all local models, so \({{f_h^k}(\boldsymbol{x}, {\boldsymbol{w}}) = {f_h^{k'}}(\boldsymbol{x}, {\boldsymbol{w}})}\) and \({{f^k_l}(\boldsymbol{x}, {\boldsymbol{w}}) = {f^{k'}_l}(\boldsymbol{x}, {\boldsymbol{w}})}\) for all pairs of clients \({{k}, {k'}}\). For simplicity, we write \({{f_l}\left( \boldsymbol{x}\right) = {f^k_l}\left( \boldsymbol{x}, {\boldsymbol{w}^k_l}\right) }\) and \({{f_h}\left( \boldsymbol{x}\right) = {f_h^k}\left( \boldsymbol{x},{\boldsymbol{w}^k_h}\right) }\) for each client k. Parameters \({\boldsymbol{w}^k_l}\) and \({\boldsymbol{w}^k_h}\) are local to client k and not shared. Finally, the personalized inference is

$$\begin{aligned} {\hat{y}_h} = {f_h}\left( \boldsymbol{x}\right) {f_l}\left( \boldsymbol{x}\right) + \left[ 1 - {f_h}\left( \boldsymbol{x}\right) \right] {f_g}\left( \boldsymbol{x}\right) . \end{aligned}$$
(4)
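As an illustration, here is a minimal sketch of such a gating function as a small network with a Softmax output layer. The hidden width and layer choices are our assumptions; [24] may use a different architecture.

```python
import torch.nn as nn

class GatingFunction(nn.Module):
    """f_h^k: maps an input x to one weight per expert model.

    The Softmax output layer makes the weights sum to 1. With two
    experts this reproduces (4): the first output weighs the local
    expert and the second weighs the global model.
    """
    def __init__(self, in_features, n_experts=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),                # x flattened to in_features
            nn.Linear(in_features, 64),  # hidden width: our assumption
            nn.ReLU(),
            nn.Linear(64, n_experts),
            nn.Softmax(dim=-1),          # weights sum to 1
        )

    def forward(self, x):
        return self.net(x)
```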

3 Adaptive Expert Models for Personalization

3.1 Framework Overview and Motivation

In IFCA, after the training phase, the cluster model with the lowest loss on the validation set is used for all future inferences. All other cluster models are discarded in the clients. A drawback of IFCA is therefore that it does not use all the information available in the clients in the form of unused cluster models. Each client has access to the full set of cluster models, and our hypothesis is that if a client can make use of all of these models we can increase performance.

It is sometimes advantageous to incorporate a local model, as in Sect. 2.5, especially when the local data distribution is very different from those of other clients. We therefore modify the MoE [24] to incorporate all the cluster models from IFCA [11] and the local model as expert models in the mixture, see Fig. 2. We revise (4) to

$$\begin{aligned} {\hat{y}_h} = {g_l}{f^k_l}\left( \boldsymbol{x}\right) + \sum _{j=1}^{{J}} {g_j^k}{f^j_g}\left( \boldsymbol{x}\right) , \end{aligned}$$
(5)

where \({g_l}\) is the local model expert weight, and \({g_j^k}\) is the cluster model expert weight for cluster j from \({f_h^k}\left( \boldsymbol{x}\right) \), see Fig. 2.
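A sketch of this personalized inference follows, assuming a gating network like the one sketched in Sect. 2.5 that now outputs \(J+1\) weights, one per cluster model plus one for the local expert. The models are assumed to be PyTorch-style callables returning per-class scores.

```python
def personalized_inference(x, gating, local_model, cluster_models):
    """Compute (5): y_h = g_l * f_l(x) + sum_j g_j * f_g^j(x).

    gating returns shape (batch, J + 1); column 0 is the local expert
    weight g_l and columns 1..J are the cluster weights g_j.
    """
    g = gating(x)
    y = g[:, 0:1] * local_model(x)          # local expert contribution
    for j, cluster_model in enumerate(cluster_models):
        y = y + g[:, j + 1:j + 2] * cluster_model(x)
    return y
```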

However, importantly, we note that setting J in [11] to a large value leaves few cluster models that actually converge, which lowers performance when they are used in an MoE. The authors of [34] note that this method is difficult to train in practice and that its performance is worse than FedAvg with fine-tuning. We differ from [11] in the cluster estimation step in that we select the same number of clients \({{K_s} = \left\lceil {C}{K} \right\rceil }\) in every communication round, regardless of J. This spreads the client updates more evenly over the global cluster models. Since cluster models are randomly initialized, one cluster model can by chance end up being updated more than the others. In subsequent communication rounds, clients are then more likely to select this cluster model, purely because it has been updated more. This also has the effect that as J increases, the quality of the updates is reduced, since each average is taken over a smaller set of clients; in turn, more iterations are needed to converge. Therefore, we make use of the \({\varepsilon }\)-greedy algorithm [31], which allows each client to either gather information about the cluster models (exploration) or use the estimated best cluster model (exploitation). In each iteration a client selects a random cluster model with probability \({\varepsilon }\) and the currently best one otherwise, see Algorithm 3.


By using the \({\varepsilon }\)-greedy algorithm we make more expert models converge and avoid mode collapse. We can then use the gating function \({f_h^k}\) from the MoE to adapt to the underlying data distributions and weigh the different expert models. We outline our setup in Fig. 1 and provide details in Fig. 2 and Algorithms 1 to 4.
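A sketch of the per-client \({\varepsilon }\)-greedy selection step is shown below. This stands in for Algorithm 3, which we only have by reference; `estimate_cluster_identity` is the helper sketched in Sect. 2.4.

```python
import random

def select_cluster_model(client, cluster_models, epsilon):
    """epsilon-greedy cluster selection for one client.

    With probability epsilon, explore a uniformly random cluster model;
    otherwise exploit the cluster model with the lowest loss on the
    local training set (the IFCA cluster identity estimation).
    """
    if random.random() < epsilon:
        return random.randrange(len(cluster_models))           # explore
    return estimate_cluster_identity(client, cluster_models)  # exploit
```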

Fig. 2. Our solution with 2 global cluster models. Each client k has one local expert model \({f_l}(\boldsymbol{x},{\boldsymbol{w}^k_l})\) and shares \({{J}=2}\) expert cluster models \({f^j_g}(\boldsymbol{x}, {\boldsymbol{w}^j_g})\) with all other clients. A gating model \({f_h}(\boldsymbol{x},{\boldsymbol{w}^k_h})\) is used to weigh the expert cluster models and produce a personalized inference \({\hat{y}_h}\) from the input \(\boldsymbol{x}\).

When a cluster model has converged it is not cost-effective to transmit this cluster model to every client, so by using per-model early stopping we can reduce communication in both uplink and downlink. Specifically, before training we initialize \({{\mathcal {J}}={\left\{ 1,2,\ldots ,J\right\} }}\). When early stopping is triggered for a cluster model we remove that cluster model from the set \({\mathcal {J}}\). The early-stopping algorithm is described in Algorithm 1.
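A sketch of this bookkeeping follows; the patience-based convergence test is our assumption, since Algorithm 1 is only referenced here and may use another criterion.

```python
def update_selectable_models(J_set, val_loss, best_loss, patience_left,
                             patience=5):
    """Per-model early stopping: drop converged cluster models from J_set.

    val_loss[j]: current validation loss of cluster model j.
    best_loss / patience_left: per-model running state. The patience
    criterion itself is an assumption; Algorithm 1 may use another test.
    """
    for j in list(J_set):
        if val_loss[j] < best_loss[j]:
            best_loss[j] = val_loss[j]
            patience_left[j] = patience
        else:
            patience_left[j] -= 1
            if patience_left[j] <= 0:
                J_set.discard(j)  # stop training/transmitting this model
    return J_set
```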

4 Experiments

4.1 Datasets

We use three different datasets with different non-IID characteristics, where the task is multi-class image classification with a varying number of classes. We summarize these datasets in Table 1.

  • CIFAR-10 [20], where we use a technique from [24] to create client partitions with a controlled Label distribution skew, see Sect. 4.2;

  • Rotated CIFAR-10 [11], where the client feature distributions are controlled by rotating CIFAR-10 images—an example of same label, different features;

  • Federated Extended MNIST (FEMNIST) [4, 5] with handwritten characters written by many writers, exhibiting many of the non-IID characteristics outlined in Sect. 2.2.

Table 1. Dataset summary statistics — number of samples per client.

4.2 Non-IID Sampling

In order to construct a non-IID dataset from the CIFAR-10 dataset [20] with the class-imbalance properties that we are interested in, we first look at [27]. There, a pathological non-IID dataset is constructed by sorting the dataset by label, dividing it into shards of 300 data samples, and giving each client 2 shards.

However, as in [24], we are interested in varying the degree of non-IIDness and therefore assign two majority classes to each client, which together make up a fraction p of the client's data samples. The remaining fraction \((1-p)\) is sampled uniformly from the other 8 classes. When \(p=0.2\), each class has an equal probability of being sampled. The case \(p=1\) is similar to the pathological non-IID setting above. In reality, p is unknown.
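Below is a sketch of this sampling scheme for a single client; the names are our own. With \(p=0.2\) and two majority classes, every class comes out equally likely (0.1 each), and \(p=1\) yields the pathological case.

```python
import numpy as np

def sample_client_labels(n_k, majority_classes, p, n_classes=10, seed=0):
    """Draw n_k class labels for one client: a fraction p from the two
    majority classes and a fraction (1 - p) uniformly from the others."""
    rng = np.random.default_rng(seed)
    minority = [c for c in range(n_classes) if c not in majority_classes]
    labels = []
    for _ in range(n_k):
        if rng.random() < p:
            labels.append(int(rng.choice(majority_classes)))
        else:
            labels.append(int(rng.choice(minority)))
    return labels
```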

4.3 Model Architecture

We start from the benchmark model defined in [4], a Convolutional Neural Network (CNN) with two convolutional layers and one fully connected layer with fixed hyperparameters. However, in our case where \({n_k}\) is small, the local model is prone to over-fitting, so a model with lower capacity is desirable. Similarly, the gating model is also prone to over-fitting, due both to the small local dataset and to the fact that it solves a multi-label classification problem with fewer classes (one per expert model) than the original multi-class classification problem. The local model, gating model and cluster models share the same underlying architecture, but hyperparameters such as the number of filters in a hidden layer are tuned individually, see Sect. 4.4. The AdamW [25] optimizer is used to train the local model and the gating model, while Stochastic Gradient Descent (SGD) [2] is used to train the cluster models to avoid issues related to momentum parameters when averaging. We use the negative log-likelihood loss in (1).
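For illustration, here is a sketch of the shared two-convolutional-layer architecture for CIFAR-10-sized inputs. The filter counts, hidden width and dropout rate below are placeholders; the actual values are tuned per model (see Sect. 4.4).

```python
import torch.nn as nn

class TwoConvNet(nn.Module):
    """Two conv layers + one fully connected layer, as in the benchmark
    model of [4]; filter/hidden sizes are tuned per model (Sect. 4.4)."""
    def __init__(self, n_classes=10, f1=32, f2=64, hidden=128, p_drop=0.3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, f1, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(f1, f2, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p_drop),
            nn.Linear(f2 * 8 * 8, hidden), nn.ReLU(),  # 32x32 input -> 8x8
            nn.Linear(hidden, n_classes),
            nn.LogSoftmax(dim=-1),  # pairs with negative log-likelihood loss
        )

    def forward(self, x):
        return self.classifier(self.features(x))
```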

4.4 Hyperparameter Tuning

Hyperparameters are tuned using [23] in four stages and used for all clients. For each model we tune the learning rate \({\eta }\), the number of filters in two convolutional layers, the number of hidden units in the fully connected layer, dropout, and weight decay. For the \({\varepsilon }\)-greedy exploration method we also tune \({\varepsilon }\).

First, we tune the hyperparameters for a local model and for a single global model. Thereafter, we tune the hyperparameters for the gating model using the best hyperparameters found in the earlier steps. Lastly, we tune \({\varepsilon }\) with two cluster models \({{J}=2}\). For the no exploration experiments we set \({{\varepsilon }=0}\).

Hyperparameters depend on p and J but we tune the hyperparameters for a fixed majority class fraction \({p}=0.2\), which corresponds to the IID case. The tuned hyperparameters are then used for all experiments. We show that our method is still robust in the fully non-IID case when \({{p}=1}\). See Table 2 for tuned hyperparameters in the CIFAR-10 experiment.
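As an illustration of one tuning stage, here is a sketch using Optuna. We do not know which tuning tool [23] denotes, so both the tool and the search ranges here are our assumptions; `validation_loss` is a hypothetical user-supplied evaluation function.

```python
import optuna

def objective(trial):
    # Search space for one stage, e.g. the cluster models (Sect. 4.4).
    lr = trial.suggest_float("eta", 1e-4, 1e-1, log=True)
    f1 = trial.suggest_categorical("conv1_filters", [16, 32, 64])
    f2 = trial.suggest_categorical("conv2_filters", [32, 64, 128])
    hidden = trial.suggest_categorical("hidden_units", [64, 128, 256])
    p_drop = trial.suggest_float("dropout", 0.0, 0.5)
    wd = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    # validation_loss is hypothetical: trains with these values and
    # returns the validation loss to minimize.
    return validation_loss(lr, f1, f2, hidden, p_drop, wd)

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=50)
```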

4.5 Results

We summarize our results for the class imbalance case exemplified with the CIFAR-10 dataset in Table 3. In Fig. 3, we see an example of how the performance varies when we increase the non-IID-ness factor p for the case when \({J}=3\). In Fig. 3a we see the performance of IFCA [11] compared to our solution in Fig. 3b. We also compare to: a local model fine-tuned from the best cluster model, an entirely local model, and an ensemble model where we include all cluster models as well as the local model with equal weights. In Fig. 4 we vary the number of cluster models J for different values of the majority class fraction p.

Table 2. Tuned hyperparameters in the CIFAR-10 experiment for the global cluster models, the local models and the gating model.
Fig. 3. Results for CIFAR-10. Comparison between no exploration and our \({\varepsilon }\)-greedy exploration method for \({J}=6\). Our solution with \({\varepsilon }\)-greedy exploration is superior in all cases from IID to pathological non-IID class distributions, here shown by varying the majority class fraction p.

Table 3. Results for CIFAR-10 and \({p} \in \{0.2, 0.4, \ldots , 1\}\) when \({J}=6\). Mean \(\mu \) and standard deviation \(\sigma \) for our exploration method \({\varepsilon }\)-greedy and without exploration. We compare to the baseline from IFCA [11]. Our proposed solution is superior in all but one case, as indicated by bold numbers.

An often overlooked aspect of performance in FL is the inter-client variance. We achieve a smaller inter-client variance, shown for CIFAR-10 in Fig. 6a and Table 3.

We see that for CIFAR-10 our \({\varepsilon }\)-greedy exploration method achieves better results for lower values of p by allowing more of the cluster models to converge—thereby more cluster models are useful as experts in the MoE, even though the models are similar, see Fig. 5a. For higher values of p we see that the cluster models are adapting to existing clusters in the data, see Fig. 5c. The most interesting result is seen in between these extremes, see Fig. 5b. We note that the same number of clients pick each cluster model as in IFCA, but we manage to make a better selection and achieve higher performance.

Fig. 4. Results for CIFAR-10. Comparison between no exploration (colored dashed lines) and the \({\varepsilon }\)-greedy exploration method (colored solid lines). Our solution with \({\varepsilon }\)-greedy exploration outperforms all other solutions, including the baseline from IFCA [11]. It performs better the greater the non-IIDness, here seen by varying the majority class fraction p. Furthermore, our solution is robust to changes in the number of cluster models J.

For the rotated CIFAR-10 case in Table 4 and Fig. 7b we see that IFCA manages to assign each client to the correct cluster at \({{J}=2}\), and in this same label, different features case our exploration method requires a larger J to achieve the same performance. We also note the very high \({{\varepsilon }=0.82}\). More work is needed on better exploration methods for this case.

Fig. 5. Results for CIFAR-10. The number of clients in each cluster for the different exploration methods. Clusters are sorted—the lowest index corresponds to the most picked cluster. Our \({\varepsilon }\)-greedy exploration method picks the cluster models more evenly.

Fig. 6. CDF of client accuracy. Comparison between no exploration (colored dashed lines) and the \({\varepsilon }\)-greedy exploration method (colored solid lines). Our solution with \({\varepsilon }\)-greedy exploration improves accuracy and fairness for two of the datasets.

The FEMNIST dataset represents a more difficult scenario since there are many non-IID aspects in this dataset. We find from Table 5 and Fig. 7a that for FEMNIST the best performance is achieved when \({{J}=9}\) and in Fig. 6b we show the distribution of accuracy for the clients for the different models.

Table 4. Results for Rotated CIFAR-10. Mean \(\mu \) and standard deviation \(\sigma \) with varying number of cluster models J, for our exploration method \({\varepsilon }\)-greedy and for the baseline exploration method from IFCA. At \({J}=2\) all clients have picked the correct cluster model.
Table 5. Results for FEMNIST. Mean \(\mu \) and standard deviation \(\sigma \) with varying number of cluster models J for our exploration method \({\varepsilon }\)-greedy and for the baseline exploration method from IFCA.
Fig. 7. Results for FEMNIST and rotated CIFAR-10. Comparison between no exploration (colored dashed lines) and the \({\varepsilon }\)-greedy exploration method (colored solid lines). Our solution is superior in the FEMNIST case, but needs more cluster models to achieve performance similar to the baseline in the rotated CIFAR-10 case.

5 Related Work

The FedAvg algorithm [27] is the most prevalent algorithm for learning a global model in FL. This algorithm has demonstrated that an average over model parameters is an efficient way to aggregate local models into a global model. However, when data is non-IID, FedAvg converges slowly or not at all. This has given rise to personalization methods for FL [14, 18]. Research on how to handle non-IID data among clients is ample and expanding. Solutions include fine-tuning locally [33], meta-learning [9, 17], MAB [30], multi-task learning [22], model heterogeneous methods [7, 13], data extension [32], distillation-based methods [16, 21] and Prototypical Contrastive FL [28].

Mixing local and global models has been explored by [6], where a scalar \(\alpha \) is optimized to combine global and local models. In [29] the authors propose to use MoE [15] and learn a gating function that weighs a local and a global expert to enhance user privacy. This work is developed further in [24], where the authors use a gating function with larger capacity to learn a personalized model when client data is non-IID. We differ in using cluster models as expert models, and by evaluating our method on datasets with different non-IID characteristics. Recent work has studied clustering in FL settings for non-IID data [3, 11, 19, 26]. In [11] the authors implement a clustering algorithm for handling non-IID data in the form of covariate shift. Their proposed algorithm learns one global model per cluster with a central parameter server, using the training loss of global models on local data of clients to perform cluster assignment. In their work, they only perform clustering in the last layer and aggregate the rest into a single model. If a cluster model is unused for some communication rounds, it is removed from the list to reduce communication overhead. However, this means that a client cannot use the other global cluster models to increase performance.

6 Discussion

We adapted the inspiring work of [11] to work better in our setting and efficiently learned expert models for non-IID client data. Sending all cluster models in each iteration introduces more communication overhead. We addressed this by removing converged cluster models from the set of selectable cluster models in Algorithm 1, although this is not used in our main results. It only affects the results to a minor degree, but has a larger effect on training time, since client updates are otherwise wasted on already converged models. Another improvement is the reduced complexity of the cluster assignment step. A notable difference between our work and IFCA is that we share all weights, as opposed to only the last layer as in [11]. These differences increase the communication overhead further, but this has not been our priority and we leave it for future work.

7 Conclusion

In this chapter, we have investigated personalization in a distributed and decentralized ML setting where the data generated on the clients is heterogeneous with non-IID characteristics. We noted that neither FedAvg nor state-of-the-art solutions achieve high performance in this setting. To address this problem, we proposed a practical framework of MoE using cluster models and local models as expert models, and improved the adaptiveness of the expert models by balancing exploration and exploitation. Specifically, we used an MoE [24] to make better use of the cluster models available in the clients and added a local model. We showed that IFCA [11] does not work well in our setting and, inspired by the MAB field, added an \({\varepsilon }\)-greedy exploration [31] method to improve the adaptiveness of the cluster models, which increased their usefulness in the MoE. We evaluated our method on three datasets representing different non-IID settings and found that our approach achieves superior performance on two of the datasets and is robust on the third. Even though we tune our algorithm and hyperparameters in the IID setting, it generalizes well in non-IID settings and with a varying number of cluster models—a testament to its robustness. For example, for CIFAR-10 we see an average accuracy improvement of 29.78% compared to IFCA and 4.38% compared to a local model in the pathological non-IID setting. Furthermore, our approach improved the inter-client accuracy variance by 60.39% compared to IFCA, which indicates improved fairness, although it is 60.98% worse than a local model.

In real-world scenarios data is distributed and often displays non-IID characteristics, and we consider personalization to be a very important direction of research. Finding clusters of similar clients to make learning more efficient is still an open problem. We believe there is potential to improve the convergence of the cluster models further, and that privacy, security and system aspects provide interesting directions for future work.