1 Introduction

Image classification is an important research direction in the field of computer vision [1,2,3]. Its main task is to automatically identify the target in an image and assign it to the corresponding category. Traditional image classification algorithms consist of three main steps: preprocessing, feature extraction, and classification. Feature extraction is the most important of these, and the quality of the extracted features directly affects classification performance. Because the features extracted by traditional algorithms tend to be redundant and generalize poorly, researchers have proposed many effective neural network models [4,5,6].

At present, owing to its unique advantages in image processing, the convolutional neural network (CNN) is widely used in image classification tasks. However, the repeated convolution and pooling operations discard a large amount of important information, so the features extracted by a CNN often lack sufficient expressive power. To address this problem, researchers have proposed a series of feature enhancement modules to improve feature robustness. However, most of these modules increase the computational load of the original network while failing to improve accuracy.

From the perspective of feature extraction and selection, existing image classification methods are usually based on two types of feature learning frameworks: the Bag of Words (BOW) model and the deep learning model.

In the traditional BOW model [7], feature representation generates a global representation for the image in three main steps: handcrafted feature extraction, feature encoding, and feature aggregation. First, features are extracted by hand on dense grids or at sparse interest points. The features are then quantized through different coding schemes. Finally, the coded features are combined by feature aggregation into an image-level representation. Common classifiers for BOW features include the Support Vector Machine (SVM) and random forest [8, 9]. In addition, BOW feature learning combined with probabilistic topic models also achieves good results. Jin [10] represented image features as dictionaries through the BOW model and used the Probabilistic Latent Semantic Analysis (PLSA) model to discover latent topics from a large number of images for classification. Ghorai [11] assumed that the latent topic space could be learned from the visual and textual modalities: PLSA was applied to the data of each modality to obtain the corresponding semantic topic distribution, and the two modalities were then fused through an adaptive asymmetric algorithm to obtain a better classification result. Filisbino [12] then proposed a learning method combining generative and discriminative models: continuous PLSA was used to model the visual features of images to reduce the impact of cluster granularity on classification performance, and an ensemble classifier chain was adopted to classify multi-label images. However, for the specific task of image classification, handcrafted features do not always give the best results.

In the computer vision field, Teng [13] showed that CNNs perform better than many other methods based on handcrafted features. Li [14] proposed that image representations learned by a CNN on large-scale data sets can be effectively transferred to other visual recognition tasks with limited training data. Yan [15] proposed an ensemble framework for classification, named OverFeat (OF), which achieved better precision than the traditional BOW model. Chaib [16] also used a pre-trained CNN model to extract image features, but classified the deep features with an ensemble classifier chain instead of an SVM, so that the associations among labeled data sets could be learned for better performance. Zhu et al. [17] proposed the spatial regularization network (SRN), which generates attention maps for all labels and captures the underlying relationships among them; a ResNet-101 network then aggregates the regularized and original classification results. Mou [18] developed a recurrent memory attention module, consisting of two alternating components, to realize interpretable image classification. You [19] proposed a multi-label classification model based on the graph convolutional network, which builds a directed graph over the objects, where each node is represented by the label's word embedding; by training the graph convolutional network, the label graph is mapped into a set of interdependent object classifiers.

The deep neural network (DNN) uses an architecture composed of several nonlinear transformations to model high-level abstractions of visual data and shows good effectiveness in image classification tasks. Teng [20] compared several traditional feature learning methods and used CNNs with several loss functions to solve the image classification problem; the results showed that the classification performance of the CNN was significantly higher than that of the traditional methods. To obtain an effective CNN model, a large number of parameters must be learned during training, but it is very difficult to train a CNN for a specific task on a limited training data set. Therefore, optimizing the CNN model through parameter transfer learning has become a widely adopted method. Many works have shown that the parameters of a CNN pre-trained on the diverse ImageNet data set can be transferred to a new model to extract features from other data sets that lack sufficient training samples. In other words, parameter transfer learning during training effectively enlarges the data set, which further optimizes the model, indirectly alleviates the sample shortage, and yields more accurate features. Some researchers also adopt semi-supervised learning methods for image classification.

The main idea of semi-supervised learning is to use the information provided by both labeled and unlabeled data to improve learning performance when little labeled data is available. Labeled data provides information about the joint distribution of data and labels, while unlabeled data provides only information about the data distribution. In reference [21], a self-training-based semi-supervised learning method was used for target detection, achieving the same effect as a traditionally trained model with a larger labeled data set. In reference [22], a ladder network built from an auto-encoding network was proposed: by adding lateral connections, a noise encoder and a corresponding denoising decoder were added to a normal feed-forward network to learn from the unlabeled training data. Zhao et al. [23] chose FCN, ResNet, and PSPNet as classifiers, trained the models with different proportions of training samples from the Jingjinji region, and then used the trained models to predict the results for the study areas. Wang et al. [24] proposed a weakly supervised deep learning framework with uncertainty estimation to address macula-related disease classification from OCT images when only volume-level labels are available. At present, semi-supervised learning methods mainly include generative models, self-training models, semi-supervised SVMs, entropy regularization models, and graph-based models [23], etc. Existing research shows that, compared with supervised learning on the labeled data alone, semi-supervised learning improves learning performance.

In this paper, a semi-supervised image classification method based on a modified GAN is proposed. By using a small amount of labeled data together with unlabeled data, the discriminator network of the GAN can output data category labels, realizing a semi-supervised classification method based on few labeled samples. Compared with supervised learning methods and other semi-supervised networks, the proposed GAN achieves better performance. The method can be applied to image classification, medical diagnosis, anomaly detection, image recognition, etc.

This paper is organized as follows. In Sect. 2, we elaborate on the proposed GAN in detail. Experiments demonstrating the effectiveness of the new image classification method are presented in Sect. 3. Conclusions are given in Sect. 4.

2 Proposed GAN

2.1 Generative adversarial networks (GAN)

GAN consists of a generator network G and a discriminator network D. Figure 1 shows the process of GAN.

Fig. 1 Process of GAN

The generator G maps random noise \(z\) drawn from a specific distribution \(P_{z}\) (such as a Gaussian or uniform distribution) to the target domain, learning the probability distribution \(P_{data}\) of the real data so that the generated sample \(G(z)\) conforms to the real data distribution \(P_{data} (x)\) as far as possible. The discriminator D determines whether an input sample comes from the real data \(x\) or the generated data \(G(z)\), and outputs the probability \(D( \cdot )\) that it belongs to the real data.

The goal of the generator is to fit the "true" data (training samples) and generate "false" data; the goal of the discriminator is to distinguish true data from false data. Both the generator and the discriminator are multi-layer perceptrons. Given a real sample set \(\{ x^{1} , \cdots ,x^{n} \}\), suppose \(p_{x}\) is its data distribution. Data are randomly sampled from another pre-defined distribution \(p_{z}\) to obtain the noise set \(\{ z^{1} , \cdots ,z^{m} \}\). Let the input of the generator be \(z\); the output "false" data can be represented as \(G(z)\). The discriminator takes "true" and "false" data in turn as input and outputs a one-dimensional scalar representing the probability that the input is true. Depending on the input, the outputs can be expressed as \(D(x)\) and \(D(G(z))\). Ideally, \(D(x)\) = 1 and \(D(G(z))\) = 0. The optimization process can be described as a two-player minimax game with the following objective function:

$$ \mathop {\min }\limits_{G} \mathop {\max }\limits_{D} E_{{x\sim p_{x} }} \ln D(x) + E_{{z\sim p_{z} }} \ln (1 - D(G(z))) $$
(1)

If the data distribution of \(G(z)\) is denoted \(p_{G}\), then the two-player minimax problem has a global optimal solution at \(p_{G}\) = \(p_{x}\). The generator and discriminator are trained alternately: when updating the parameters of one, the parameters of the other are fixed. In general, the discriminator learns faster than the generator. To keep the two in sync, the generator is trained \(k\) times for every one training step of the discriminator. Experiments show that the relative learning ability of the generator and discriminator changes over time; therefore, in the subsequent experiments, this paper designs a dynamic learning method that keeps the two in sync by observing the changes in the loss function values.
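To make the alternating scheme concrete, the following is a minimal PyTorch sketch (PyTorch is the framework used in Sect. 3). The MLP sizes, noise dimension, optimizer choice, update ratio, and stand-in data are illustrative assumptions rather than the configuration used in this paper; the generator update uses the common non-saturating form of Eq. (1).

```python
import torch
import torch.nn as nn

# Multi-layer perceptron generator and discriminator, as described above.
# Sizes are illustrative: 100-d noise, 784-d samples (e.g. flattened 28x28).
G = nn.Sequential(nn.Linear(100, 256), nn.ReLU(), nn.Linear(256, 784), nn.Tanh())
D = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid())

opt_G = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4)
bce = nn.BCELoss()   # bce(p, 1) = -ln p, bce(p, 0) = -ln(1 - p)
k = 2                # generator steps per discriminator step (tunable)

# Stand-in for a real data loader over the training samples x^1..x^n.
data_loader = [torch.randn(128, 784) for _ in range(100)]

for x in data_loader:
    batch = x.size(0)
    ones, zeros = torch.ones(batch, 1), torch.zeros(batch, 1)

    # k generator updates: maximize ln D(G(z)) (non-saturating form of Eq. 1).
    for _ in range(k):
        z = torch.randn(batch, 100)
        loss_G = bce(D(G(z)), ones)
        opt_G.zero_grad(); loss_G.backward(); opt_G.step()

    # One discriminator update: maximize ln D(x) + ln(1 - D(G(z))).
    z = torch.randn(batch, 100)
    loss_D = bce(D(x), ones) + bce(D(G(z).detach()), zeros)
    opt_D.zero_grad(); loss_D.backward(); opt_D.step()
```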

2.2 Modified GAN

The generator and discriminator are in an adversarial relationship, and the ultimate goal is to enable the generator to fit the data distribution of the real samples perfectly. Lacking the guidance of supervisory information, the fitting process is highly random. In practice, limited by the learning ability of the network, the generator usually fits only part of the real data distribution, losing some modes; this is known as mode collapse. As shown in Fig. 2, mode collapse leads to redundant training results, poor image quality, and other problems. Analyzing real data, it is not difficult to find that different modes exhibit significant differences, for example, men and women among human faces, or day and night among scenes, but also share connections, such as facial structure and object shape and position. Fitting only part of the distribution emphasizes the differences while ignoring the connections; the key to solving the problem is to find a balance between them.

Fig. 2 Mode collapse in GAN. a The synthetic data distribution cannot fit the real data distribution well; b mode collapse leads to redundant synthetic data

Thus, this paper designs the network structure shown in Fig. 3. Two (or more) generators are built and trained synchronously, sharing one input and one discriminator; the training method is the same as in the classic GAN. In addition, the generators learn from each other in a step called "collaboration," guiding each other and improving together. Collaboration is interspersed with normal training, and the ratio can be adjusted to the situation at hand, for example, two training steps for every collaboration step. From the perspective of data distribution, as shown in Fig. 4, classic GAN training shortens the distance between the real distribution and the generated distribution, while collaborative training shortens the distance between the distributions of the different generators. This approach not only improves the convergence speed of the model, but also enhances its learning ability and reduces the possibility of mode collapse.

Fig. 3 Proposed GAN

Fig. 4 Fitting process in proposed network

Because the input and the discriminator are shared between the generators, the generator distributions may overlap. This not only fails to achieve the desired goal but also creates an additional network load. To avoid this, the generators are designed with different network structures and random weight initializations. In practice, the overlap problem does not occur during training, and the results produced by the different generators always differ to some extent. The objective function of the discriminator is:

$$ \max E_{{x\sim p_{x} }} \ln D(x) + E_{{z\sim p_{z} }} \ln (1 - D(G_{1} (z))) + E_{{z\sim p_{z} }} \ln (1 - D(G_{2} (z))) $$
(2)
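A minimal PyTorch sketch of this discriminator objective follows, reusing the kind of networks sketched in Sect. 2.1; minimizing the returned loss maximizes Eq. (2). The function signature is an illustrative assumption, not part of the paper.

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()  # bce(p, 1) = -ln p, bce(p, 0) = -ln(1 - p)

def discriminator_loss(D, G1, G2, x, z):
    """Negative of Eq. (2): ln D(x) + ln(1 - D(G1(z))) + ln(1 - D(G2(z)))."""
    ones = torch.ones(x.size(0), 1)
    zeros = torch.zeros(x.size(0), 1)
    return (bce(D(x), ones)
            + bce(D(G1(z).detach()), zeros)   # generated data from G1
            + bce(D(G2(z).detach()), zeros))  # generated data from G2
```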

For the generators, the term \(E_{{x\sim p_{x} }} \ln D(x)\) does not depend on their parameters, so their objective function is:

$$ \max E_{{z\sim p_{z} }} \ln D(G_{1} (z)) + E_{{z\sim p_{z} }} \ln D(G_{2} (z)) + \lambda L $$
(3)
$$ L = - ||G_{1} (z) - G_{2} (z)||_{2} $$
(4)

where \(\lambda\) is a constant. The collaboration factor L uses the L2-norm to shorten the distance between the generators. \(D(G_{1} (z))\) and \(D(G_{2} (z))\) are the discriminator's scores for the data generated by G1 and G2, respectively. Define

$$ s = D(G_{1} (z)) - D(G_{2} (z)) $$
(5)

When s > 0, G1 obtains the higher score from discriminator D; that is, \(G_{1} (z)\) looks more like real data. The distance from G2 to G1 should therefore be shortened, which is done by fixing the parameters of G1, calculating the collaboration factor L, and penalizing the connection weights of G2. When s < 0, the roles are reversed: G2 is fixed and G1 is penalized. The severity of the penalty is related to the magnitude of s. In this way, the generator with the higher score exerts an attractive force on the generator with the lower score. Owing to the randomness of the network, G1 and G2 are trained alternately and assist each other, and both finally converge to the real data distribution. In summary, this network structure is called the "cooperative GAN".
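Below is a minimal PyTorch sketch of one collaboration step built from Eqs. (3)-(5), under the same illustrative assumptions as the earlier sketches; the value of \(\lambda\) and the exact way the penalty is scaled by |s| are assumptions, since the paper does not fix them here. In practice this step would be interleaved with the normal adversarial updates at the chosen ratio (e.g., two training steps per collaboration step).

```python
import torch

def collaborate(G1, G2, D, z, opt_G1, opt_G2, lam=0.1):
    """One collaboration step: pull the lower-scoring generator toward the
    higher-scoring one. lam is the constant lambda from Eq. (3)."""
    x1, x2 = G1(z), G2(z)
    s = (D(x1) - D(x2)).mean()       # Eq. (5), averaged over the batch
    w = lam * s.detach().abs()       # penalty severity grows with |s|
    if s.item() > 0:
        # G1 scored higher: fix G1 (detach) and penalize G2 with -L (Eq. 4).
        loss = w * torch.norm(x1.detach() - x2, p=2)
        opt_G2.zero_grad(); loss.backward(); opt_G2.step()
    else:
        # G2 scored higher: fix G2 and penalize G1 instead.
        loss = w * torch.norm(x1 - x2.detach(), p=2)
        opt_G1.zero_grad(); loss.backward(); opt_G1.step()
```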

3 Experiments and Analysis

In this paper, the proposed GAN is evaluated on two common benchmarks: Cifar and PASCAL VOC 2012. The comparison experiments are conducted under the Linux 16.04 operating system with a 32-core 2.1 GHz CPU and a TITAN Xp GPU. PyTorch is the deep learning framework.

In this section, the Cifar10 and Cifar100 data sets are used to evaluate the effectiveness of the proposed algorithm. Each data set contains 60,000 color images of size 32 × 32; 50,000 images are used for training and 10,000 for testing. Data augmentation is done with the usual methods: random cropping and horizontal flipping.
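As a concrete illustration, the augmentation pipeline just described is commonly written in torchvision as follows; the padding of 4 pixels before cropping is the usual Cifar convention and an assumption here, since the paper does not state it.

```python
import torchvision.transforms as T
from torchvision.datasets import CIFAR10

train_transform = T.Compose([
    T.RandomCrop(32, padding=4),   # random cropping of the 32x32 images
    T.RandomHorizontalFlip(),      # random horizontal flipping
    T.ToTensor(),
])

# Example usage with the Cifar10 training split.
train_set = CIFAR10(root="./data", train=True, download=True,
                    transform=train_transform)
```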


To ensure fairness, all experiments use the Stochastic Gradient Descent (SGD) method [26]. The momentum is set to 0.9, the training batch size to 128, the testing batch size to 100, the weight decay to 0.0005, and the learning rate to 0.1. The final result is the average over five test runs. The comparison methods are SR [27], CWA [28], and JFIS [29]. Table 1 shows the average classification accuracy of the four algorithms with different numbers of labeled samples, and Fig. 5 shows the corresponding trend.
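For reference, these hyper-parameters map onto a PyTorch SGD optimizer as sketched below; `model` is a hypothetical stand-in for the network being trained.

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(3 * 32 * 32, 10)  # placeholder for the actual network

optimizer = optim.SGD(
    model.parameters(),
    lr=0.1,             # learning rate
    momentum=0.9,       # momentum
    weight_decay=5e-4,  # weight decay ("weight attenuation") of 0.0005
)
```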

Table 1 Average classification accuracy with different labeled samples

Fig. 5 The trend of average classification accuracy with different labeled samples

It can be seen from the above results that the proposed GAN achieves the same classification performance as JFIS with only a small amount of labeled data. The proposed GAN significantly improves the classification performance and outperforms the other two models. Table 2 shows the Top-1 error rates.

Table 2 Top-1 error rate with different methods

As the last column of Table 2 shows, the proposed GAN has the lowest Top-1 error rate among the compared methods. The proposed GAN effectively improves the classification performance of the network while adding few parameters, which shows that the module generalizes well.

Tables 3 and 4 show the classification results of the four methods on the Cifar and PASCAL VOC 2012 data sets, which also demonstrate the better performance of the proposed GAN.

Table 3 Results on Cifar data set
Table 4 Results on PASCAL VOC 2012 data set

Table 5 lists the average accuracy, average error, and running time. As can be seen from Table 5, the average accuracy, average error, and running time of the proposed GAN are 98.2%, 2.7%, and 0.56 s, respectively, outperforming SR, CWA, and JFIS.

Table 5 Comparison of average accuracy, average error, and running time

We also give sample image classification results, as shown in Table 6. It can be seen that the classification results of the proposed model are better than those of the other methods in most cases. The "sofa" in the second image, the "dining table" in the fourth image, and the "chair" in the fifth image are partially occluded, yet the proposed GAN recognizes these objects correctly, indicating stronger robustness to occluded objects. Additionally, the proposed GAN recognizes the small object "bottle" in the third image, indicating an improved ability to recognize small objects. On further comparison, for the first image, the "dining table" predicted by the new model is closer to the original semantics of the image than the manually labeled "bicycle". This shows that even when the model's classification is inconsistent with the manual label, it can still reflect the image semantics correctly to some extent.

Table 6 Image classification sample

4 Conclusions

In this paper, an improved GAN model is proposed to address the insufficient representation ability of features extracted by existing neural network models and the resulting low image classification and recognition accuracy. By constructing multiple generators and introducing a cooperative mechanism, the generators learn from each other and improve together, which significantly improves image classification quality, accelerates network convergence, improves learning efficiency, and reduces the possibility of mode collapse. Experimental results on open public data sets show that the new GAN model performs better in image classification than other advanced models. In the future, more advanced deep learning methods will be applied to image classification.