1 Introduction

Unsupervised clustering aims to partition data into distinct categories without relying on labels, providing a solution to the challenge of obtaining large amounts of labeled data for classification tasks [33]. Traditional techniques, such as K-means [40] and probabilistic mixture models [10, 13, 14, 15, 44], have been developed for this purpose. However, directly clustering the original data often leads to suboptimal results because numerous irrelevant factors interfere with the clustering process [63]. Additionally, the training time of a model escalates rapidly as the dataset size increases.

Deep generative models, including autoencoders (AEs), variational autoencoders (VAEs) [31], and generative adversarial networks (GANs) [21], have demonstrated remarkable success across diverse domains by effectively extracting meaningful information from raw data into compact latent spaces. As a result, these models have received significant attention and have been applied in various fields [4, 25, 27, 37, 39, 50, 51]. In the realm of unsupervised clustering, VAEs in particular have been extensively employed. A VAE is a deep generative model that combines variational inference [9, 11, 12] with deep neural networks. However, previous VAE-based clustering approaches optimized both the generation and clustering components within the same framework, which entangled the clustering and generating factors in the latent variables. Moreover, several powerful clustering models required pre-training with stacked autoencoders, imposing significant demands on computing resources and storage space. To address these challenges, FSVAE [58] proposed a disentanglement strategy for the D-dimensional latent variable. This strategy divides the latent variable into two parts: the first \(D_1\) dimensions are assigned to the generation module, which follows a Gaussian prior distribution, and the remaining \(D_2\) dimensions form the clustering module, which follows a Student’s t mixture model (STMM) prior [62]. Note that \(D=D_1+D_2\), and the clustering module benefits from the use of augmented data to enhance its effectiveness.

The strategic use of augmented data can markedly improve clustering accuracy. Prior models [29, 57] leveraged augmented data in a simplistic or constrained manner, which limited its ability to guide clustering. In contrast, FSVAE uses augmented data more thoroughly, applying basic constraints to the latent space, the cluster assignment variables, and the encoders. However, these models depend predominantly on the mean square error loss function, which only enforces numerical similarity between the latent and cluster assignment variables. Our proposed method advances this by employing mutual information maximization. It further uses the cluster assignment variables to shape the latent variables and to construct new latent variables. These techniques enrich the information available during training for both the cluster assignment and latent variables, thereby enhancing clustering performance.

The limitations of previous models, specifically their underutilization of augmented data, have led to suboptimal clustering effectiveness. To address this, our research introduces a novel unsupervised clustering model, focused on image clustering, leveraging the FSVAE framework. Our primary innovation involves the application of mutual information maximization, as outlined in Ji et al [28]. Through a streamlined objective function, we aim to enhance clustering effectiveness by maximizing mutual information between cluster assignment variables in both original and augmented data. Further, by leveraging augmented data and cluster assignment variables, our model optimizes the latent variables, as suggested in Haeusser et al [24]. This optimization not only extracts more informative features in the latent space but also significantly improves clustering performance. Finally, our approach introduces a unique method that merges the clustering module of the original data with the generation module of the augmented data. This synthesis results in a new type of latent variable, enabling the model to ignore extraneous information not pertinent to clustering, thereby further refining clustering effectiveness.

The main contributions of this work can be summarized as follows:

  • We introduce the mutual information maximization technique to optimize the model’s cluster assignment variables using the augmented data, thereby improving the model’s robustness and clustering effectiveness.

  • We constrain the latent variables by using augmented data and cluster assignment variables, thereby improving the robustness of the model and the effectiveness of clustering.

  • We propose a new latent variable construction method to help the model ignore irrelevant information, thereby improving the robustness of the model and the effectiveness of clustering.

The remainder of this paper is organized as follows. Section 2 presents an overview of existing work related to the proposed model. Section 3 provides the necessary background knowledge for our model. In Sect. 4, we introduce our VAE-based deep generative clustering method, which includes the disentangled modules and the utilization of augmented data. The experimental results obtained with the proposed model are presented in Sect. 5. Finally, Sect. 6 concludes the paper, summarizing the main findings and contributions.

2 Related works

Unsupervised deep clustering models have exhibited impressive performance in various clustering tasks. Among these models, deep unsupervised generative clustering approaches, such as AE, VAE, and GAN, have achieved notable success in both clustering and data reconstruction tasks.

One widely recognized AE-based deep generative clustering model is Deep Embedding Clustering (DEC) [54]. DEC employs the K-means algorithm in the pre-training stage to generate cluster centers for each cluster. These cluster centers are then iteratively optimized using an auxiliary distribution. However, DEC only utilizes a complete AE model during the pre-training stage, discarding the decoder in the subsequent training stage and solely using the encoder to optimize the cluster centers. Improved Deep Embedding Clustering (IDEC) [22] addresses the limitations of DEC by jointly training the reconstruction loss and clustering loss in the AE framework. This joint optimization enhances the assignment of clustering labels, enables learning of discriminative clustering features, and preserves the local structure. Adding the decoder back to the model during training leads to improved clustering results. Another DEC-based model, Deep Convolutional Embedding Clustering (DCEC) [23], incorporates convolutional neural networks (CNNs) to enhance clustering performance through improved feature extraction capabilities. The DSSEC model, as detailed in Cai et al [1], combines a sparse autoencoder with DEC for enhanced cluster analysis. Concurrently, the DCN algorithm, introduced in Yang et al [56], is an autoencoder (AE)-based clustering approach that principally utilizes the K-means algorithm for cluster formation. A notable development in this field is the DEPICT model [20], which employs a softmax layer, derived from stacked AE pre-training, to predict cluster assignments, thereby yielding considerable improvements in performance. Furthermore, the DECCA framework, proposed in Diallo et al [6], adopts a contractive learning methodology to cultivate more effective latent variables. The ACe/DeC model [46] categorizes the information captured by the latent variables, effectively distinguishing between cluster-specific and shared information spaces, while the TSAE model [16] integrates teacher-student models and autoencoders for cluster analysis. Lastly, IMDGC [61] integrates hierarchical generative adversarial networks and mutual information maximization to improve clustering effectiveness.

Variational Deep Embedding (VaDE) [29] is a deep generative clustering model based on the VAE framework. Unlike the VAE, which uses a standard Gaussian distribution as the prior, VaDE employs a Gaussian mixture model as the prior distribution. S3VDC [2] is a generative clustering model that builds upon VaDE. It incorporates initial \(\gamma\) training to optimize the pre-training of VaDE. Drawing inspiration from \(\beta\)-VAE [25], S3VDC introduces periodic \(\beta\) annealing to promote disentanglement in the VaDE model, allowing it to capture more informative representations. S3VDC also adopts mini-batch Gaussian Mixture Model (GMM) initialization to enhance scalability and employs an inverse min-max transformation to mitigate NaN (Not-a-Number) issues during training. The vMF-VaDE model, as elucidated by Yang et al. [59], has demonstrated exceptional clustering performance across various datasets. FSVAE [58] extends the VAE model by treating the clustering and generation information in the data separately. It utilizes the Student’s t mixture model as the prior for the clustering module. FSVAE incorporates a bi-augmentation module to enhance training stability and notably achieves optimal performance without the need for pre-training. Other VAE-based generative clustering models include GMVAE [7] and DSCDAN [60]. In addition to the previously mentioned models, the field of deep generative clustering encompasses GAN-based approaches such as clusterGAN [47] and Va-GAN [57]. These models integrate a GAN component to enhance clustering effectiveness. Another notable model is Dual-AAE [19], which extends the Adversarial Autoencoder (AAE) [42] framework.

Autoencoder-based models including DEC, IDEC, DCEC, DCN, DEPICT, DECCA, and ACe/DeC, have shown notable success in clustering, with a primary emphasis on feature extraction. Unlike these models, VAE-based clustering approaches, such as VaDE, S3VDC, and GMVAE, generally incorporate Gaussian mixture models as priors, thereby enhancing clustering effectiveness. However, a limitation arises in these models’ susceptibility to collapse, particularly when not utilizing the Student’s t-distribution. For instance, the vMF-VaDE model, despite achieving optimal results, necessitates pre-training and is especially prone to collapse due to its reliance on the vMF distribution. In contrast, GAN-based models like clusterGAN, Va-GAN, and Dual-AAE leverage adversarial networks and employ strategies such as WGAN-GP to mitigate model collapse.

The primary limitation of the aforementioned models is their methodology of performing cluster analysis on the entire latent variable space. Differing from this approach, FSVAE introduces an innovative disentanglement strategy, segregating latent variables into two distinct categories: clustering variables and generated variables. Furthermore, FSVAE employs a Student’s t-mixture model to prevent model collapse, a technique that has proven to significantly improve clustering effectiveness. Despite these advancements, FSVAE relies primarily on a simple mean square error loss function to direct the clustering of latent variables using augmented data, which could be seen as a constraint on its potential. In contrast, our model expands the utility of augmented data by implementing mutual information maximization. This enhancement not only constrains latent variables but also facilitates the construction of new latent variables, thereby allowing augmented data to more effectively guide the clustering process within the model.

Motivated by the works of Ji et al [28]; Haeusser et al [24], we incorporate the technique of maximizing mutual information in our generative model and introduce constraints on the latent variables. Furthermore, building upon the FSVAE framework mentioned earlier, we propose a novel method for constructing latent variables that allows the model to disregard irrelevant information for clustering. These three approaches collectively enhance the robustness and effectiveness of our model.

3 Preliminary

3.1 Variational autoencoder

The VAE is a neural network architecture that employs a Gaussian distribution as the prior for the latent space. This choice provides VAE with enhanced capabilities compared to the AE, such as the ability to estimate prediction uncertainties [17]. The primary objective of optimizing the VAE model can be formulated as follows:

$$\begin{aligned} \mathcal{L}(\theta ,\phi ;x) = {\mathbb {E}}_{q_\phi (z|x)}[\log p_\theta (x|z)] - D_{KL}(q_\phi (z|x)\,||\,p_\theta (z)), \end{aligned}$$
(1)

where \({\mathbb {E}}[\cdot ]\) denotes the expectation operator. In Eq. 1, the first term is commonly referred to as the reconstruction error, while the second term is the regularization term, defined by the Kullback–Leibler (KL) divergence.
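To make the objective concrete, the following is a minimal PyTorch sketch of the negative of Eq. 1 for a diagonal Gaussian posterior and a Bernoulli decoder; the function and variable names are illustrative rather than taken from our implementation.

```python
import torch
import torch.nn.functional as F

def vae_negative_elbo(x, x_recon, mu, logvar):
    """Negative of Eq. (1): reconstruction error plus KL regularizer.

    x, x_recon : tensors of shape (batch, D); mu, logvar parameterize q(z|x).
    """
    # Reconstruction term, assuming a Bernoulli likelihood for binarized images.
    recon = F.binary_cross_entropy(x_recon, x, reduction="sum")
    # Closed-form KL(q_phi(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl
```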

3.2 VAE with Gaussian mixture prior

The Gaussian mixture model has found applications in various fields [32, 38, 49] and can be utilized as a prior for the VAE.

For a given dataset X, we assume that the data x is generated by a random process. Moreover, for any potential embedding z of the data x, we consider it to follow a Gaussian Mixture Model (GMM) with K clusters. This GMM serves as the prior for the VAE model [29]. In the VAE’s generation process with a GMM prior, we can generate z from the GMM distribution, which is defined by

$$\begin{aligned} p(z|y) = \mathcal{N}(z|{\mu _y},\sigma _y^2), \end{aligned}$$
(2)

where \(\mu _y\) and \(\sigma _y^2\) represent the parameters of the Gaussian distribution for cluster y. The variable y follows a categorical distribution \(Cat(\pi )\), where \(p(y) = Cat(\pi )\), and \(Cat(\pi )\) denotes the categorical distribution parameterized by \(\pi\).

By employing the GMM as the prior for the VAE, the final optimization objective can be expressed as:

$$\begin{aligned} \begin{aligned} \mathcal{L}\left( x \right)&= \frac{1}{L}\sum \limits _{l = 1}^L \sum \limits _{d = 1}^D \Big [ x_d \log \big (\mu _{xd}^{(l)} + \sigma _{xd}^{2(l)}\varepsilon _d^{(l)}\big ) \\&\quad + (1 - x_d)\log \big (1 - (\mu _{xd}^{(l)} + \sigma _{xd}^{2(l)}\varepsilon _d^{(l)})\big )\Big ] \\&\quad - \frac{1}{2}\sum \limits _{y = 1}^K q(y|x) \sum \limits _{j = 1}^J \left( \log \sigma _{yj}^2 + \frac{\overrightarrow{\sigma }_j^2}{\sigma _{yj}^2} + \frac{(\overrightarrow{\mu }_j - \mu _{yj})^2}{\sigma _{yj}^2}\right) \\&\quad + \sum \limits _{y = 1}^K q(y|x)\log \frac{\pi _y}{q(y|x)} + \frac{1}{2}\sum \limits _{j = 1}^J \big (1 + \log \overrightarrow{\sigma }_j^2\big ), \end{aligned} \end{aligned}$$
(3)

where L represents the number of Monte Carlo samples, and \(\overrightarrow{\mu }\) and \(\overrightarrow{\sigma }^{2}\) are the parameters of the Gaussian distribution modeled by the encoder. D denotes the dimensionality of x, \(\mu _{x}^{(l)}\), and \(\sigma _{x}^{(l)}\), while J is the dimensionality of the parameters \(\mu _{y}\), \(\sigma _{y}\), \(\overrightarrow{\mu }\), and \(\overrightarrow{\sigma }^{2}\). The cluster assignment variable q(y|x) can be obtained using the SGVB estimator as:

$$\begin{aligned} q(y|x) = \frac{1}{L}\sum \limits _{l = 1}^L {\frac{{p(y)\mathcal{N}\left( {{z^{(l)}}|{\mu _y},\sigma _y^2} \right) }}{{\sum \nolimits _{{y^{'}} = 1}^K {p({y^{'}})\mathcal{N}\left( {{z^{(l)}}|{\mu _{{y^{'}}}},\sigma _{{y^{'}}}^2} \right) }}}}, \end{aligned}$$
(4)

where \(\mathcal{N}(\cdot )\) represents the Gaussian distribution, and p(y) is the prior probability of cluster y.
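A minimal sketch of how Eq. 4 can be evaluated in PyTorch is given below; the tensor shapes and the use of log-space normalization are our own assumptions rather than details of the referenced models.

```python
import torch

def cluster_posterior(mu_z, logvar_z, pi, mu_y, logvar_y, L=1):
    """Monte Carlo estimate of q(y|x) from Eq. (4).

    mu_z, logvar_z : (batch, J) parameters of q(z|x) from the encoder
    pi             : (K,) mixture weights p(y)
    mu_y, logvar_y : (K, J) Gaussian parameters of each cluster
    """
    q_y = torch.zeros(mu_z.size(0), pi.size(0), device=mu_z.device)
    for _ in range(L):
        eps = torch.randn_like(mu_z)
        z = mu_z + torch.exp(0.5 * logvar_z) * eps          # reparameterized sample z^(l)
        # log p(y) + log N(z | mu_y, sigma_y^2); constants independent of y cancel in the softmax.
        diff = z.unsqueeze(1) - mu_y.unsqueeze(0)           # (batch, K, J)
        log_pdf = torch.log(pi).unsqueeze(0) - 0.5 * (
            logvar_y.unsqueeze(0) + diff ** 2 / torch.exp(logvar_y).unsqueeze(0)
        ).sum(dim=2)
        q_y += torch.softmax(log_pdf, dim=1)                # normalize over the K clusters
    return q_y / L
```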

3.3 VAE with Student’s t mixture prior

The Student’s t mixture model is widely employed across diverse domains [5, 18, 45] and offers a robust alternative to the Gaussian mixture model, particularly when handling outliers.

In Sect. 3.2, we introduced a deep generative clustering model that employed a GMM as the prior for the VAE. While this approach demonstrated significant clustering capabilities, it suffered from performance degradation in the presence of outliers. To overcome this limitation, we utilize the Student’s t distribution as a more robust alternative, owing to its heavy-tailed characteristics. By incorporating the Student’s t mixture model (STMM) into the VAE framework, we can enhance the model’s robustness [62]. The probability density function (PDF) of the Student’s t distribution is defined by

$$\begin{aligned} \mathcal{S}(z|\mu ,\sigma ^2,v) = \frac{\varGamma (\frac{v + 1}{2})}{\varGamma (\frac{v}{2})\sqrt{\pi v}\,\sigma } \left( 1 + \frac{(z - \mu )^{T}\sigma ^{-2}(z - \mu )}{v}\right) ^{-\frac{v + 1}{2}}, \end{aligned}$$
(5)

where v, \(\mu\), and \(\sigma\) are the distribution parameters, and \(\varGamma\) denotes the gamma function. We adopt the reparameterization trick for the Student’s t distribution [43, 48], which is similar to that of the Gaussian distribution:

$$\begin{aligned} z = \mu + \sqrt{\frac{v}{{2{{\widetilde{z}}}}}} \sigma \cdot \varepsilon , \end{aligned}$$
(6)

where \(\varepsilon \sim \mathcal{N}\left( {0,1} \right)\), and \({{\widetilde{z}}}\sim \mathcal{G}(\frac{v}{2},1)\). Here, \({{\widetilde{z}}}\) is obtained using the reparameterization trick for the Gamma distribution as

$$\begin{aligned} {{\widetilde{z}}} = \left( \frac{v}{2} - \frac{1}{3}\right) {\left( 1 + \frac{\varepsilon }{{\sqrt{\frac{{9v}}{2} - 3} }}\right) ^3}, \end{aligned}$$
(7)

where \(\varepsilon \sim \mathcal{N}\left( {0,1} \right)\).
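The two reparameterization steps of Eqs. 6 and 7 can be sketched in PyTorch as follows; this is only an illustration of the formulas above, not our exact implementation.

```python
import torch

def sample_student_t(mu, sigma, v):
    """Reparameterized sample from the Student's t distribution (Eqs. 6-7).

    mu, sigma, v : tensors of the same shape (location, scale, degrees of freedom);
    v is assumed large enough that 9v/2 - 3 > 0.
    """
    # Gamma(v/2, 1) sample via the approximate reparameterization of Eq. (7).
    eps_gamma = torch.randn_like(v)
    z_tilde = (v / 2 - 1.0 / 3) * (1 + eps_gamma / torch.sqrt(9 * v / 2 - 3)) ** 3
    # Gaussian noise rescaled as in Eq. (6).
    eps = torch.randn_like(mu)
    return mu + torch.sqrt(v / (2 * z_tilde)) * sigma * eps
```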

In the VAE model using STMM as the prior, a notable distinction from Sect. 3.2 is the assumption that z is generated by STMM, thus conforming to the Student’s t distribution in the latent space. This assumption necessitates the encoder to provide an additional parameter v, which is unique to the Student’s t distribution.

4 The proposed model

This section outlines our clustering model, which integrates VAE with disentangled representations to enhance clustering performance. Figure 1 provides a visual representation of the model’s architecture. The methodology is systematically detailed across several subsections: Sect. 4.1 delves into the process of disentangling latent variables within the model. Section 4.2 describes the bi-augmentation modules employed in FSVAE, enhancing its efficacy. In Sect. 4.3, we articulate our approach for maximizing mutual information, ensuring a robust correlation between original and augmented data. Section 4.4 explains how we impose constraints on the latent variables using cluster assignment variables and augmented data to refine the clustering process. Finally, Sect. 4.5 presents our innovative latent variables, offering both the mathematical formulation and the algorithmic structure of the model, underscoring its practical and theoretical contributions.

Fig. 1

The network architecture of the proposed model. Initially, we employ both feature augmentation and data augmentation techniques to enrich the training dataset. The model uniquely integrates a disentanglement strategy for latent variables, segregating them into two modules: \(Z_{g}\) for general features and \(Z_{c}\) for cluster-specific features. To enforce constraints, we utilize mean square error loss on both Z and \({\tilde{Z}}\), the latter representing augmented latent variables. Similar constraints are also applied to the cluster assignment variables. The function \({\mathcal {F}}_c\) is designated for maximizing mutual information between the original and augmented data, while \({\mathcal {F}}_z\) focuses on constraining the latent variables. Additionally, we introduce a novel latent variable, synthesized by combining the latent variables of both the original and augmented images, which is depicted at the top of the figure

4.1 Disentanglement of latent representations

Various disentangling methods have been proposed in the past, aiming to ensure that a single latent variable corresponds to a single factor [3, 8, 25, 30, 64]. In this subsection, we provide a comprehensive elucidation of the methodology employed to disentangle the model’s latent variables into two distinct modules: clustering and generative. Additionally, we present the training formulas specifically tailored for the disentangled clustering model.

In the VAE model, the presence of a decoder responsible for generating reconstructed data suggests that the latent variable z contains valuable information for the generation process. Expanding on this idea, we can consider the latent variable z as composed of two distinct modules: the clustering module and the generation module [58]. The generation module \(z_g\) follows a Gaussian distribution, while the clustering module \(z_c\) aligns with the description in Sect. 3.3, utilizing the Student’s t Mixture Model (STMM) as its prior. The concatenation operation \(\oplus\) combines the two modules as follows:

$$\begin{aligned} z = {z_g} \oplus {z_c}. \end{aligned}$$
(8)

By allowing each module to fulfill its respective role, the clustering effectiveness and generation capabilities of the model can be enhanced.
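A sketch of this disentangled latent construction is given below, assuming the encoder outputs Gaussian posterior parameters for the generation module and Student's t parameters for the clustering module; the interface is hypothetical, and `sample_student_t` refers to the sketch in Sect. 3.3.

```python
import torch

def build_disentangled_latent(mu_g, logvar_g, mu_c, sigma_c, v_c):
    """Eq. (8): sample the generation and clustering modules and concatenate them."""
    # Generation module z_g: standard Gaussian reparameterization.
    z_g = mu_g + torch.exp(0.5 * logvar_g) * torch.randn_like(mu_g)
    # Clustering module z_c: Student's t reparameterization (sketch in Sect. 3.3).
    z_c = sample_student_t(mu_c, sigma_c, v_c)
    # z = z_g concatenated with z_c along the feature dimension.
    return torch.cat([z_g, z_c], dim=1), z_g, z_c
```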

Based on the discussions in Sects. 3.1, 3.2, and 3.3, the final optimization objective of the model can be expressed as

$$\begin{aligned} \begin{aligned} \mathcal{L}_{fs}\left( x \right)&= {E_{q({z_g},{z_c},y|x)}}\left[ \log \frac{{p(x,{z_g},{z_c},y)}}{{q({z_g},{z_c},y|x)}}\right] \\&= {E_{q({z_g},{z_c},y|x)}}[\log p(x|{z_g},{z_c})] \\&\quad - {D_{KL}}(q({z_g}|x)||p({z_g}))\\&\quad - {D_{KL}}(q({z_c},y|x)||p({z_c},y)), \end{aligned} \end{aligned}$$
(9)

where the optimization objective consists of three parts: the reconstruction error, the loss of the generation module, and the loss of the clustering module with the STMM as the prior.

4.2 Bi-augmentation module

In this subsection, we present a thorough analysis of the bi-augmentation modules implemented in the FSVAE. These modules play a pivotal role in augmenting the clustering capabilities of the model.

The bi-augmentation module consists of two components, feature augmentation and data augmentation, both of which play a crucial role in the clustering process. First, we introduce a transformed image, denoted as \({\widetilde{x}}\), into the model. Since x and \({\widetilde{x}}\) depict the same underlying image, we assume that their corresponding latent variables, z and \({\widetilde{z}}\), are similar or identical. To align z and \({\widetilde{z}}\), a data augmentation constraint is applied. Similarly, the same constraint is imposed on the cluster assignment variables \(\gamma =q(y|x)\) and \({\widetilde{\gamma }}=q(y|{\widetilde{x}})\). The loss function for data augmentation can be formulated as follows:

$$\begin{aligned} {\mathcal{L}_{aug}} = {\mathcal {C}}({z_c},{{{\widetilde{z}}}_c}) + {\mathcal {C}}({z_g},{{{\widetilde{z}}}_g}) + {\mathcal {C}}(\gamma ,{{\widetilde{\gamma }}}), \end{aligned}$$
(10)

where \({\mathcal {C}}\) denotes the mean square error loss function.
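A direct PyTorch transcription of Eq. 10 is sketched below; it simply accumulates mean square errors between the corresponding quantities of the two views.

```python
import torch.nn.functional as F

def augmentation_loss(z_c, z_c_aug, z_g, z_g_aug, gamma, gamma_aug):
    """Eq. (10): mean square error between the latent variables and cluster
    assignments of an image and its augmented view."""
    return (
        F.mse_loss(z_c, z_c_aug)
        + F.mse_loss(z_g, z_g_aug)
        + F.mse_loss(gamma, gamma_aug)
    )
```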

Secondly, we enhance the model by splitting the CNN into two parts and incorporating feature normalization [35, 36] on the output of the preceding CNN segment. The feature normalization operation is defined as:

$$\begin{aligned} {{\overline{h}}} = {{\widehat{\mu }}} + {{\widehat{\sigma }}} \frac{{\overrightarrow{h} - \overrightarrow{\mu }}}{{\overrightarrow{\sigma }}}, \end{aligned}$$
(11)

where \({\overline{h}}\) represents the input to the subsequent CNN segment, while \(\overrightarrow{h}\) denotes the output of the image from the preceding CNN segment. The parameters \(\overrightarrow{\mu }\) and \(\overrightarrow{\sigma }\) correspond to the mean and standard deviation of \(\overrightarrow{h}\), respectively. Similarly, the augmented image undergoes the same processing to obtain \({\widehat{\mu }}\) and \({\widehat{\sigma }}\) for feature normalization.
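The feature normalization of Eq. 11 can be sketched as follows, assuming the statistics are computed over the batch dimension of features flattened to (batch, features); this is an illustrative choice, not a detail fixed by the equation.

```python
import torch

def feature_normalization(h, h_aug, eps=1e-5):
    """Eq. (11): standardize the features h of the original image and re-scale them
    with the statistics of the augmented features (assumed batch-wise here)."""
    mu, sigma = h.mean(dim=0), h.std(dim=0) + eps              # statistics of h
    mu_hat, sigma_hat = h_aug.mean(dim=0), h_aug.std(dim=0)    # statistics of the augmented view
    return mu_hat + sigma_hat * (h - mu) / sigma
```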

4.3 Augmented mutual information

In this subsection, we delve into a comprehensive explanation of the techniques employed to maximize mutual information between the original and augmented data. By emphasizing the synergy between original and augmented data, we aim to elucidate how this strategy significantly enhances the model’s performance.

While the augmentation techniques discussed in Sect. 4.2 provide some assistance for clustering, there remain several deficiencies that require improvement. In Sect. 4.2, the data augmentation module is employed to enforce a mean square error loss between the cluster assignment variables of the original and augmented data. This aligns with the intuitive perception that the original and augmented data represent the same underlying information, and therefore, their cluster assignment variables should be consistent. Inspired by the work of Ji et al. [28], we posit that maintaining consistency in the information of cluster assignment variables between original and augmented data is crucial. Consequently, our research further investigates the maximization of mutual information in this context. Specifically, our focus lies in enhancing the mutual information of the cluster assignment variables. This approach is designed to facilitate the discovery of more refined and effective cluster assignment representations, thereby improving the overall clustering performance of our model.

In our model, for each data point \(x_i\) in the dataset \(X=\{x_1, x_2,..., x_N\}\), we obtain the probability of data \(x_i\) belonging to each category using Equation (4). The augmentation of mutual information aims to maximize the alignment of cluster assignment variables between the original and augmented data. The augmented mutual information loss is defined by

$$\begin{aligned} {\mathcal{L}_I} = \max I(\gamma , {\gamma ^{'}}), \end{aligned}$$
(12)

where \(\gamma\) represents the cluster assignment variable for the original data and \(\gamma ^{'}\) represents the cluster assignment variable for the augmented data. Since \(\gamma\) can be treated as a discrete random variable distributed across K categories, we can directly compute \(I({\gamma },\gamma ^{'})\) as

$$\begin{aligned} I(\gamma ,{\gamma ^{'}}) = \sum \limits _{y = 1}^K {\sum \limits _{y' = 1}^K {{P_{y{y^{'}}}} \cdot \log \frac{{{P_{y{y^{'}}}}}}{{{P_y} \cdot {P_{y'}}}}} }. \end{aligned}$$
(13)

For any pair of samples \((x, x')\), the conditional joint distribution is given by \(P(\gamma = y,\gamma ^{'} = y^{'}|x,{x^{'}}) = {\gamma ^{y}}{\gamma ^{'y^{'}}}\). By marginalizing over the dataset, we can obtain the joint distribution P of the cluster assignment variables using the following equation

$$\begin{aligned} P = \frac{1}{n}\sum \limits _{i = 1}^n {{\gamma _i} \cdot {\gamma _i}^{{'^{T}}}}, \end{aligned}$$
(14)

where n denotes the number of sample pairs.

We considered the order of sample pairs as \((x, x')\). However, it is also valid to consider the order as \((x', x)\). Consequently, we obtain the following joint distribution:

$$\begin{aligned} {P_{y{y^{'}}}} = (P + {P^{T}})/2, \end{aligned}$$
(15)

where the marginal distributions \({P_y}\) and \({P_{y'}}\) can be obtained by summing the rows and columns of \(P_{y{y'}}\). Maximizing the mutual information of the cluster assignment variables helps to identify the similarities between samples, thereby improving the accuracy of clustering.
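The mutual information objective of Eqs. 13–15 follows the estimator of Ji et al [28]; a compact PyTorch sketch is given below, returning the negative mutual information so that it can be minimized alongside the other loss terms.

```python
import torch

def mutual_information_loss(gamma, gamma_aug, eps=1e-10):
    """Eqs. (13)-(15): symmetrized joint distribution of paired cluster assignments,
    followed by their mutual information (returned negated for minimization)."""
    n = gamma.size(0)
    P = gamma.t() @ gamma_aug / n                 # Eq. (14): joint distribution over clusters
    P = (P + P.t()) / 2                           # Eq. (15): symmetrize over pair ordering
    P = P.clamp(min=eps)                          # guard against log(0)
    P = P / P.sum()                               # renormalize
    Py = P.sum(dim=1, keepdim=True)               # marginal over rows
    Py_aug = P.sum(dim=0, keepdim=True)           # marginal over columns
    mi = (P * (P.log() - Py.log() - Py_aug.log())).sum()   # Eq. (13)
    return -mi
```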

4.4 Augmenting latent variables

In Sect. 4.3, we discussed the constraints imposed on the cluster assignment variables. While acknowledging the utility of this approach, we recognize that the application of mean square error loss to latent variables, in isolation, may not yield substantial improvements in clustering performance. Inspired by Haeusser et al [24], this subsection introduces a strategy to apply more stringent constraints on the latent variables of the data. This enhanced approach is aimed at further refining the model’s clustering capabilities, drawing upon advanced methodologies to achieve more significant improvements.

For a sample pair \((x_i, x_j)\), we construct the loss as follows:

$$\begin{aligned} {\mathcal{L}_{trafo}} = |1 - z_{c,i}^{T}z_{c,j}^{'} - {\ell _2}({\gamma _i},\gamma _j^{'})|, \end{aligned}$$
(16)

where \({\ell _2}\) represents the cross-entropy loss function. Since we only constrain the clustering modules in the latent variables, we refer to them as \(z_c\) in the equation. The latent variable \(z_c\) is derived from the original data, while \(z_c'\) is obtained from the augmented data.

Unlike in Sect. 4.3, where the sample pair consists of the original data and its augmented counterpart, \({\mathcal{L}_{trafo}}\) has two optimized forms. In the first case, we set \(x_j\) to be the augmented data of \(x_i\). In this scenario, when the cluster assignment variables are similar, the latent variables \({z_i}\) and \({z_j}\) of the image become more alike, thereby imposing a stronger constraint on the latent variables. In the second case, we set \(x_i\) and \(x_j\) to correspond to different data and optimize \({\mathcal{L}_{trafo}}\). By doing so, we encourage the latent variables \({z_i}\) and \({z_j}\) to be similar when their cluster assignment variables are similar. However, if the cluster assignment variables are dissimilar, the latent variables \({z_i}\) and \({z_j}\) will be different. Since we believe that cluster assignment variables belonging to the same category should be similar, this approach encourages the latent variables of the same category to be more compact, while keeping the latent variables of different categories more separated.

Table 1 Experimental results of two different trafo losses on the MNIST dataset

We conducted a comparison of clustering accuracy and NMI (Normalized Mutual Information) for the two cases on the MNIST dataset, and the results are presented in Table 1. In the table, “Baseline” refers to the experimental results without utilizing the trafo loss, “Case1” represents the experimental results of the first case where trafo loss is applied to the baseline model, and “Case2” corresponds to the experimental results of the second case where trafo loss is applied to the baseline model. The table clearly demonstrates that the trafo loss in the first case yields significantly better clustering performance compared to the second case. For the sake of brevity, we rewrite equation (16) in the form of the first case as

$$\begin{aligned} {\mathcal{L}_t} = |1 - {z_{c}^{T}}{z_{c}^{'}} - {\ell _2}(\gamma ,{\gamma ^{'}})|. \end{aligned}$$
(17)
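A sketch of Eq. 17 in PyTorch is given below; we assume the clustering latents are \(\ell_2\)-normalized before the inner product and that \(\gamma^{'}\) plays the role of the target in the cross-entropy, both of which are illustrative choices rather than details fixed by the equation.

```python
import torch.nn.functional as F

def trafo_loss(z_c, z_c_aug, gamma, gamma_aug, eps=1e-10):
    """Eq. (17): pull the clustering latents of a pair together when their soft
    cluster assignments agree, and apart when they do not."""
    # Inner product of the (l2-normalized) clustering latents of each pair.
    sim = (F.normalize(z_c, dim=1) * F.normalize(z_c_aug, dim=1)).sum(dim=1)
    # Cross-entropy between the two soft cluster assignments.
    ce = -(gamma * (gamma_aug + eps).log()).sum(dim=1)
    return (1 - sim - ce).abs().mean()
```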

4.5 Latent variable construction

In this subsection, we provide a comprehensive description of our method for constructing new latent variables for training purposes. We also briefly discuss the rationale behind this approach and present a figure illustrating the process of creating these new latent variables. Additionally, the final formulation and algorithm of the model are outlined.

In the previous section, we employed the latent variable z, derived from the original data via the encoder, for reconstruction in the decoder. This method operates under the assumption that z contains ample information for effective clustering, while it overlooks the potential contributions of augmented data. However, using augmented data for reconstruction might inadvertently incorporate irrelevant information into z, which could impede clustering performance. For example, in cases where simple rotations are used for data augmentation, the latent variable may inadvertently encode rotation-related information. We hypothesize, though, that if the latent variables can assimilate a modest amount of additional information, the model could, during its training phase, learn to autonomously disregard this extraneous data along with other irrelevant factors, thus enhancing clustering accuracy. As illustrated in Fig. 2, we suggest substituting the generation component associated with the original data’s latent variable with that of the augmented data’s latent variable for reconstruction. This alteration enables the latent variable to encompass a broader range of information, facilitating its ability to eliminate non-essential elements for clustering and consequently improving the overall efficiency of the clustering model.

Fig. 2

New latent variables constructed on all datasets

We express our newly constructed latent variables as

$$\begin{aligned} {{\overline{z}}} = {{{\widetilde{z}}}_g} \oplus {z_c}. \end{aligned}$$
(18)

Similar to the latent variables in the original model, we utilize \({\overline{z}}\) in the decoder to obtain \({\overline{x}}\), and then compute the cross-entropy loss between \({\overline{x}}\) and x. Our objective function is defined as

$$\begin{aligned} {\mathcal{L}_r} = \ell (x,{{\overline{x}}} ), \end{aligned}$$
(19)

where \(\ell\) represents the cross-entropy loss function. Notably, our proposed method for constructing latent variables exclusively relies on augmented data, making it applicable to other models with disentangled latent variables.
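The construction of Eqs. 18 and 19 reduces to a concatenation followed by a decoding pass; the sketch below assumes the decoder outputs per-pixel Bernoulli means (values in [0, 1]).

```python
import torch
import torch.nn.functional as F

def mixed_latent_reconstruction_loss(decoder, z_g_aug, z_c, x):
    """Eqs. (18)-(19): decode the new latent (augmented generation module, original
    clustering module) and compare the result with the original image."""
    z_bar = torch.cat([z_g_aug, z_c], dim=1)        # Eq. (18)
    x_bar = decoder(z_bar)                          # reconstruction from the new latent
    return F.binary_cross_entropy(x_bar, x)         # Eq. (19)
```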

The overall loss function of the model can be expressed as

$$\begin{aligned} {\mathcal{L}_{total}} = - {\mathcal{L}_{fs}}(x) + {\mathcal{L}_{aug}} + {\lambda _I}{\mathcal{L}_I} + {\lambda _t}{\mathcal{L}_t} + {\lambda _r}{\mathcal{L}_r}, \end{aligned}$$
(20)

where the hyperparameters \(\lambda _I\), \(\lambda _t\), and \(\lambda _r\) tune the contributions of the loss terms \({\mathcal {L}}_I\), \({\mathcal {L}}_t\), and \({\mathcal {L}}_r\), respectively.

The algorithm of our model through the SGVB estimator is shown in Algorithm 1.

Algorithm 1

Training steps
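For completeness, a condensed sketch of one training step corresponding to Algorithm 1 is given below; the `model.encode`, `model.elbo`, and `model.decode` interfaces are hypothetical placeholders, and the individual loss functions refer to the sketches in the preceding subsections.

```python
def training_step(model, x, x_aug, optimizer, lambda_I=1.0, lambda_t=1.0, lambda_r=0.01):
    """One SGVB training step of the proposed model (a sketch of Algorithm 1)."""
    optimizer.zero_grad()
    # Encode both views; each call is assumed to return the disentangled latents and q(y|.).
    z_g, z_c, gamma = model.encode(x)
    z_g_aug, z_c_aug, gamma_aug = model.encode(x_aug)
    # Individual loss terms (see the sketches in Sects. 4.1-4.5).
    L_fs = model.elbo(x, z_g, z_c, gamma)                                        # Eq. (9)
    L_aug = augmentation_loss(z_c, z_c_aug, z_g, z_g_aug, gamma, gamma_aug)      # Eq. (10)
    L_I = mutual_information_loss(gamma, gamma_aug)                              # Eqs. (12)-(15)
    L_t = trafo_loss(z_c, z_c_aug, gamma, gamma_aug)                             # Eq. (17)
    L_r = mixed_latent_reconstruction_loss(model.decode, z_g_aug, z_c, x)        # Eqs. (18)-(19)
    loss = -L_fs + L_aug + lambda_I * L_I + lambda_t * L_t + lambda_r * L_r      # Eq. (20)
    loss.backward()
    optimizer.step()
    return loss.item()
```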

5 Experimental results

In this section, we present the experimental results obtained from evaluating our approach on five distinct image datasets. We assess the performance using two metrics. The hardware setup used for the experiments includes an Intel i7-11700 CPU running at 2.50GHz, a GeForce RTX 3060 GPU, and 32GB of memory. The software environment consists of Python version 3.6 and PyTorch version 1.10.0.

5.1 Datasets

MNIST: LeCun et al [34] The MNIST dataset comprises 70,000 handwritten digit images, featuring 10 distinct classes ranging from ’0’ to ’9’. The resolution is 28\(\times\)28 pixels. To augment the dataset, we apply random rotations within the range of -25 to 25 degrees, adjust the image contrast, and use PIL.ImageChops.offset() to randomly shift the image by up to 0.35 times its size.

USPS: The USPS dataset comprises 9,298 handwritten digit images, covering the numbers ’0’ to ’9’. The resolution is 16\(\times\)16 pixels. Unlike the MNIST dataset, this dataset includes random rotations ranging from -35 to 35 degrees.

GTSRB: Houben et al [26] The GTSRB dataset comprises images of 43 different categories of traffic signs. For our experiments, we select 10 categories, resulting in a total of 15,540 images. The resolution is 28\(\times\)28 pixels. Similar to the USPS dataset, the GTSRB dataset includes random rotations ranging from -180 to 180 degrees.

YTF: Wolf et al [52] The YTF dataset comprises 4,733 face images in 20 categories. The resolution is 32\(\times\)32 pixels. The augmentation method used for this dataset aligns with the approach employed for the GTSRB dataset.

Fashion-MNIST: Xiao et al [53] The Fashion-MNIST dataset comprises 70,000 images, encompassing 10 categories of fashion items such as T-shirts, coats, sandals, etc. The resolution is 28\(\times\)28 pixels. The augmentation procedure employed for this dataset aligns with that of the MNIST dataset.
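As an illustration of the augmentation pipelines described above, a possible PIL implementation for the MNIST case is sketched below; the contrast-factor range and the uniform sampling of the random shift are our own assumptions, since the text does not specify them.

```python
import random
from PIL import Image, ImageChops, ImageEnhance

def augment_mnist(img: Image.Image) -> Image.Image:
    """A sketch of the MNIST augmentation: rotation, contrast change, random offset."""
    # Random rotation within [-25, 25] degrees.
    img = img.rotate(random.uniform(-25, 25))
    # Random contrast adjustment (factor range assumed).
    img = ImageEnhance.Contrast(img).enhance(random.uniform(0.5, 1.5))
    # Random shift of up to 0.35 times the image size in each direction.
    max_dx, max_dy = int(0.35 * img.width), int(0.35 * img.height)
    dx = random.randint(-max_dx, max_dx)
    dy = random.randint(-max_dy, max_dy)
    return ImageChops.offset(img, dx, dy)
```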

Detailed information about these datasets is shown in Table 2.

Table 2 Summary of the benchmark datasets

5.2 Implementation details

We evaluate the performance of our approach using two metrics: clustering accuracy (ACC) and normalized mutual information (NMI) [55]. Our model is implemented in PyTorch and optimized using the Adam optimizer with hyperparameters \(\beta _{1}=0.5\) and \(\beta _{2}=0.99\). The learning rate for all datasets is set to 2e-3 and decays by \(95\%\) every 10 epochs to prevent overfitting. We use a batch size of 64 and train the model for 500 epochs. The network architecture of our model is summarized in Table 3, and the distribution of latent variables across different datasets is provided in Table 4.
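A sketch of the corresponding training setup is given below; we read the “decays by 95% every 10 epochs” schedule as multiplying the learning rate by 0.95 every 10 epochs, and `training_step` refers to the sketch following Algorithm 1 in Sect. 4.5.

```python
import torch

def train(model, loader, epochs=500):
    """Optimizer and schedule of Sect. 5.2: Adam (beta1=0.5, beta2=0.99), lr 2e-3,
    StepLR with factor 0.95 every 10 epochs, batch size 64, 500 epochs."""
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3, betas=(0.5, 0.99))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.95)
    for _ in range(epochs):
        for x, x_aug in loader:                        # mini-batches of size 64
            training_step(model, x, x_aug, optimizer)  # see the sketch in Sect. 4.5
        scheduler.step()
```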

Table 3 Network settings for all datasets
Table 4 Latent dimension for all data sets

5.3 Experimental results

Table 5 The ACC results by different methods without pre-training
Table 6 The NMI results by different methods without pre-training

In this experiment, we compare our method, which does not involve pre-training, to several deep clustering algorithms, including K-means [40], GMM [44], AE+K-means, DEC [54], IDEC [22], GMVAE [7], ClusterGAN [47], S3VDC [2], DECCA [6], TSAE [16], and IMDGC [61]. Additionally, we assess the effectiveness of our newly added losses by comparing our model to the baseline FSVAE [58]. The experimental results across multiple experiments are reported in Tables 5 and 6. With the exception of ACC and NMI on the F-MNIST dataset, our model achieves the highest ACC and NMI among the models without pre-training. Moreover, it outperforms the baseline FSVAE, demonstrating the effectiveness of our newly added losses.

Moreover, we conduct a comprehensive comparison between our proposed model and nine advanced clustering methods that incorporate pre-training. These methods include DCEC [23], DCN [56], VaDE [29], Va-GAN [57], DEPICT [20], Dual-AAE [19], DSCDAN [60], ACe/DeC [46], and vMF-VaDE [59]. The comparative results are detailed in Table 7. Notably, the performance metrics within parentheses correspond to results achieved without the use of pre-training. Our model not only matches but in some instances surpasses the clustering accuracy of these pre-training-based methods across various datasets. This underlines the robustness and effectiveness of our approach, even in the absence of pre-training.

Table 7 The performance comparison of different methods with or without the pre-training step, where the values in brackets denote the performance obtained without using pre-training
Fig. 3

Results of clustering visualization on the MNIST dataset. We selected epoch 0, epoch 200, and the final results for display

Fig. 4

This figure shows the ACC and NMI of the model under different hyperparameters on the MNIST dataset

Fig. 5

This figure shows the ACC and NMI of the model under different hyperparameters on the USPS dataset

To visualize the separation of cluster assignments among latent variables during training, we utilize t-SNE [41] on the entire MNIST dataset, as shown in Fig. 3.

In Equation (20), our newly proposed loss function incorporates three hyperparameters: \({\lambda _{I}}\), \({\lambda _{t}}\), and \({\lambda _{r}}\). We assessed the impact of various hyperparameter settings on the model’s Accuracy (ACC) and Normalized Mutual Information (NMI) using the MNIST and USPS datasets, as illustrated in Figs. 4 and 5.

In Fig. 4, our investigation into the MNIST dataset involved adjusting all three hyperparameters within the range of 0.0001 to 1. For \({\lambda _{I}}\) and \({\lambda _{t}}\), we observed minimal variations in performance in the 0.0001 to 0.1 range, followed by a consistent increase, achieving optimal performance at a setting of 1. Based on these findings, we selected \({\lambda _{I}}\)=1 and \({\lambda _{t}}\)=1 for subsequent experiments. In contrast, \({\lambda _{r}}\) displayed a decreasing trend in performance, leading us to choose \({\lambda _{r}}\)=0.01 for further experimentation. Regarding the USPS dataset, as depicted in Fig. 5, we similarly adjusted the hyperparameters from 0.0001 to 1. Across the board, an increase in hyperparameter values corresponded with an upward trend in the model’s performance. Consequently, we set all three hyperparameters to 1 for subsequent analyses.

To further analyze the impact of our two newly added losses on the model’s clustering performance, we conduct an ablation experiment, and the results are presented in Table 8. In the table, “Baseline” refers to the original model, “Ours w/o \({\mathcal {L}}_t\)+\({\mathcal {L}}_r\)” indicates the absence of both augmented latent variables and the construction of new latent variables, “Ours w/o \({\mathcal {L}}_I\)+\({\mathcal {L}}_r\)” indicates the absence of both augmented mutual information and the construction of new latent variables, “Ours w/o \({\mathcal {L}}_I\)+\({\mathcal {L}}_t\)” indicates the absence of both augmented mutual information and augmented latent variables, “Ours w/o \({\mathcal {L}}_r\)” indicates the absence of constructing new latent variables, “Ours w/o \({\mathcal {L}}_t\)” indicates the absence of augmented latent variables, “Ours w/o \({\mathcal {L}}_I\)” indicates the absence of augmented mutual information, and “Ours” corresponds to the final model.

Table 8 Ablation experiments of different components on the MNIST and USPS dataset

Finally, we evaluated the model’s robustness to outliers. Following the approach described in FSVAE [58], we conducted experiments by introducing outliers. We compared our model against several pre-trained deep generative clustering methods, including VaDE, Dual-AAE, and VaGAN, as well as a selection of non-pre-trained deep generative clustering methods, including VaDE, Dual-AAE, VaGAN, GMVAE, S-VaDE, and S3VDC. Additionally, we compared our results with the baseline FSVAE to highlight the enhanced robustness of our model. Table 9 presents the ACC values on the MNIST test dataset with 5%, 10%, and 15% outliers. We observed that our model exhibits superior robustness compared to all non-pre-trained models, surpasses certain pre-trained models, and outperforms the baseline model FSVAE in terms of robustness.

Table 9 The ACC performance with different outlier ratios on the MNIST-test dataset
Table 10 The Running time results by different methods on the MNIST dataset

The computational complexity of our proposed method is \(O(m\cdot n\cdot s)\), where m represents the number of epochs, n quantifies the ratio of the total data volume to the batch size, and s signifies the execution time of the encoder and decoder neural networks. As illustrated in Table 10, a comparative analysis of the epoch-wise runtime between our model and existing models is presented. To ensure a fair comparison, we standardized the encoder and decoder across all models. A distinctive advantage of our model is its elimination of the need for pre-training. This attribute significantly contributes to its efficiency, particularly in achieving a reduced runtime when compared with other methods.

6 Conclusion

In this work, we introduce a novel generative clustering approach based on FSVAE, aimed at significantly enhancing clustering performance. Our methodology encompasses several key advancements. First, we improve clustering efficiency by maximizing the mutual information between the cluster assignment variables of the original and augmented data. Second, by integrating augmented data and cluster assignment variables, we impose more rigorous constraints on the latent variables, thereby achieving improved clustering results. Finally, our approach includes the creation of new latent variables, which are then used in the reconstruction process to further boost the clustering effect. Comparative experimental results affirm that our model outperforms existing methods, demonstrating its superior capabilities in clustering tasks.