1 Introduction

Multi-label text classification (MLTC) focuses on assigning one or multiple class labels to a document given a candidate label set. It has been applied to many fields such as tag recommendation [7], sentiment analysis [8], and text tagging on social media [18]. It differs from multi-class text classification, which aims to predict exactly one of several mutually exclusive labels for a document [6].

Two types of information should be captured for the MLTC task. One is intra-class information, which concerns the data distribution of samples belonging to the same category. The other is inter-class information, which models relationships between classes, such as label co-occurrence and hierarchy.

Prior efforts for multi-label text classification mainly focus on learning enhanced text representations [1, 13, 20, 22]. These models feed the text representation into a set of linear classifiers, each of which predicts whether the given document belongs to a certain class. During training, the linear classifiers capture intra-class information by learning the decision boundaries of their corresponding classes. However, these methods neglect inter-class information, since the linear classifiers are trained independently and never interact with each other.

Recently, extracting the inter-class information has drawn researchers’ attention [15, 16, 19, 23]. Some studies construct a label graph according to the inter-class information and convert the graph into node features via random walk-based node embedding methods [23] or graph neural networks (GNN) [15, 16, 19]. The probability that a document belongs to a class is calculated by the dot product of the document features and the corresponding node features. These methods capture inter-class information at the cost of the expressiveness of intra-class information. For node embedding-based methods, the node embeddings are optimized in advance with the objective of reconstructing neighbors. The optimized node embeddings, which carry information for reconstruction rather than for text classification, occupy the limited capacity that could otherwise be used for modeling intra-class information. For GNN-based methods, the message passing process harms the expressiveness of intra-class information because the decision boundaries of classes receive noise from other nodes.

In this paper, we propose the Aggregating Intra-class and Inter-class information Framework (AIIF) for MLTC. AIIF consists of a text encoder and a two-branch classification layer. In the classification layer, the linear branch applies multiple linear classifiers to capture intra-class information. The graph-assisted branch employs a graph neural network on a label graph, where the message passing process captures inter-class information. Each branch takes the text feature as input and makes predictions independently. The two branches’ predictions are aggregated by a subsequent fusion module, which is optimized during training. With this divide-and-conquer architecture, AIIF captures both intra- and inter-class information and prevents the modeling of the two from interfering with each other. Besides, AIIF supports plug-and-play usage, i.e., existing studies focusing on enhanced text representation or on extracting inter-class information can be coupled with AIIF by serving as the text encoder or the graph-assisted branch.

To evaluate the effectiveness of AIIF, we implement an instance of AIIF with BERT [5] and GCN [9], and evaluate it on the widely used RCV1 and AAPD datasets. Experimental results show that the instance outperforms its variants without the two-branch classifier by a large margin. Besides, the instance achieves state-of-the-art results on the RCV1 dataset and competitive scores on the AAPD dataset.

The main contributions of this paper are listed as follows:

  • We propose AIIF, a novel MLTC framework which can capture both intra-class and inter-class information.

  • We implement an instance of AIIF. Experimental results show that the instance outperforms the baselines and achieves competitive results on two public MLTC datasets.

  • To the best of our knowledge, we are the first to analyze MLTC from the view of intra- and inter-class information. We hope this work can provide a new perspective to the community.

2 Related Work

As mentioned in Sect. 1, existing MLTC work focuses on two directions: improving text representation and extracting the inter-class information.

To obtain a good text representation, many neural models have been applied, such as CNN [13], RNN [14, 20, 22], the combination of CNN and RNN [10], and BERT [1, 3]. Some models consider improving text representation with the interaction between the input document and labels [6, 20, 22]. In addition, some methods construct a graph of words or documents to capture non-consecutive and long-distance semantics within a document or the whole corpus [17, 21]. We argue that these methods neglect the inter-class information.

Capturing the inter-class information has attracted much attention in recent years. The main idea is modeling relationships between labels as graphs and guiding the multi-label prediction with the graph representation. For example, Zhang et al. [23] construct a category graph according to label correlations and use a random-walk-based method to encode the graph. Most subsequent work applies neural networks to encode label graphs, such as Tree-LSTM [24] and variants of GCN [15, 16, 24]. Lu et al. [15] propose aggregating knowledge from multiple label graphs. Ma et al. [16] propose predicting the label graph according to each document. However, these methods sometimes perform worse than their variants without label graphs [2, 15], which may be because they ignore intra-class information.

The above methods focus on either intra-class or inter-class information alone. In contrast, AIIF applies a two-branch architecture to capture both types of information.

3 Methods

As shown in Fig. 1, AIIF separates the modeling of intra- and inter-class information with a two-branch classification layer. The classification layer takes the representation of the input document, which is obtained by the text encoder, as input. The linear branch captures intra-class information with a set of linear binary classifiers. The graph-assisted classifier branch models inter-class information by first encoding the label graph as node features, then calculating the dot product between node features and text features. The fusion module combines the predictions of the linear and graph-assisted branches as the probability that the input document belongs to each category.

Fig. 1. The overall architecture of AIIF.

Problem Formulation. Given the label space \(\mathcal {L} = \{\boldsymbol{l}_1, \boldsymbol{l}_2, \ldots , \boldsymbol{l}_T\}\) and an input document \(\boldsymbol{x}\), MLTC aims at predicting a label set \(\boldsymbol{L_{pred}} \subseteq \mathcal {L}\) for \(\boldsymbol{x}\).

In the remainder of this section, we first describe the workflow of AIIF when BERT serves as the text encoder and GCN serves as the graph encoder. Then we introduce the training of AIIF. Finally, we introduce the method for constructing the label graph.

3.1 AIIF Models

Text Encoder. Basically, given an input document \(\boldsymbol{x} = \{\boldsymbol{x}_1, \boldsymbol{x}_2, \ldots , \boldsymbol{x}_n\}\), where \(\boldsymbol{x}_i\) is the i-th token in the document, BERT converts \(\boldsymbol{x}\) to the text feature as

$$\begin{aligned} \boldsymbol{H_t} = BERT(\boldsymbol{x})~, \end{aligned}$$
(1)

where \(\boldsymbol{H_t} \in \mathbb {R}^{d_t}\) is the text feature. As recommended in [5], we add a special “[CLS]” token in front of \(\boldsymbol{x}\) before feeding it into the BERT model, and take the feature of the “[CLS]” token produced by BERT as the text feature of the input document \(\boldsymbol{x}\).
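A minimal sketch of this step is shown below, assuming the HuggingFace transformers library; the checkpoint name bert-base-uncased and the helper name encode are illustrative choices, not prescribed by the paper.

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")

def encode(document: str, max_length: int = 350) -> torch.Tensor:
    # The tokenizer prepends "[CLS]" automatically; we take its final hidden
    # state as the text feature H_t (d_t = 768 for BERT-base), as in Eq. (1).
    inputs = tokenizer(document, truncation=True, max_length=max_length,
                       return_tensors="pt")
    outputs = bert(**inputs)
    return outputs.last_hidden_state[:, 0, :]  # shape: (1, 768)
```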

Graph-assisted Classifier. Given the category graph \(\mathbb {G}\), we obtain its node features via a graph convolutional network (GCN). We choose GCN for two reasons: First, GCN can capture the complex topological structure among labels through message passing. Second, GCN can be jointly trained with the other parts of the model in an end-to-end manner. We denote every graph convolutional layer as a non-linear function \(f(\cdot ,\cdot )\), which takes the adjacency matrix \(\boldsymbol{A'}\) and node features \(\boldsymbol{H_g}^l \in \mathbb {R}^{T \times d_g^l}\) as input. Here, T is the number of nodes (labels) and \(d_g^l\) is the dimension of the node features at layer l.

$$\begin{aligned} \begin{aligned} \boldsymbol{H_g}^{l+1}&= f(\boldsymbol{H_g}^l, \boldsymbol{A'}) \\&= \sigma (\boldsymbol{A'}\boldsymbol{H_g}^l \boldsymbol{W}^l) \\ \end{aligned} \end{aligned}$$
(2)

Here, \(\boldsymbol{W}^{l} \in \mathbb {R}^{d_g^l\times d_g^{l+1}}\) is a weight matrix to be learned, and \(\sigma (\cdot )\) is a non-linear activation function, implemented with ReLU in this paper.

We treat the output of the last convolutional layer as node features, denoted by \(\boldsymbol{L_G}\). The graph-assisted classifier produces the prediction for each label according to the dot product between text features and node features.

$$\begin{aligned} \begin{aligned} f^g(\boldsymbol{H_t})&= \boldsymbol{H_t} \boldsymbol{W^g} \\ z^g_i&= {\boldsymbol{L_G}}_i \cdot f^g(\boldsymbol{H_t})~ \end{aligned} \end{aligned}$$
(3)

where \(f^g(\cdot )\) is a linear transformation with weight \(\boldsymbol{W^g} \in \mathbb {R}^{d_t \times d_g}\). The transformation maps the text feature to the dimension of \(\boldsymbol{L_G}\).
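The graph-assisted branch of Eqs. (2)–(3) can be sketched as follows. The two-layer depth, the use of nn.Linear to hold the GCN weights, and all class and argument names are our assumptions; only the dimensions \(d_t=768\) and \(d_g=256\) come from the paper.

```python
import torch
import torch.nn as nn

class GraphAssistedClassifier(nn.Module):
    """Two GCN layers over the label graph (Eq. 2) followed by a dot product
    with the transformed text feature (Eq. 3)."""
    def __init__(self, num_labels: int, d_t: int = 768, d_g: int = 256):
        super().__init__()
        self.w1 = nn.Linear(d_t, d_g, bias=False)   # W^0 of the first GCN layer
        self.w2 = nn.Linear(d_g, d_g, bias=False)   # W^1 of the second GCN layer
        self.f_g = nn.Linear(d_t, d_g, bias=False)  # f^g: maps H_t to dimension d_g

    def forward(self, h_t, h_g0, adj):
        # h_t: (batch, d_t) text features; h_g0: (T, d_t) initial node features;
        # adj: (T, T) re-weighted adjacency matrix A'.
        h = torch.relu(adj @ self.w1(h_g0))   # Eq. (2), first layer
        l_g = torch.relu(adj @ self.w2(h))    # Eq. (2), second layer -> L_G
        return self.f_g(h_t) @ l_g.t()        # Eq. (3): z^g, shape (batch, T)
```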

Linear Classifiers and Fusion Module. The linear classifiers make a prediction for each label as

$$\begin{aligned} \begin{aligned} f^p(\boldsymbol{H_t})&= \boldsymbol{H_t} \boldsymbol{W^p} \\ z^w_i&= f^p(\boldsymbol{H_t}) \cdot {\boldsymbol{L_W}}_i \end{aligned} \end{aligned}$$
(4)

where \(f^p(\cdot )\) is a linear transformation with \(\boldsymbol{W^p} \in \mathbb {R}^{d_t \times d_p}\), and \(\boldsymbol{L_W}\) is the weight matrix of the linear classifiers.

The fusion module aggregates the predictions of the linear classifiers \(z^w\) and the graph-assisted classifier \(z^g\) by weighted summation as follows,

$$\begin{aligned} z = z^g \cdot \mu + z^w \cdot (1-\mu ) \end{aligned}$$
(5)

where \(\mu \in \mathbb {R}^T\) is a trainable parameter vector representing the ratio of inter-class information for each category.
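A sketch of the full two-branch classification layer, including the fusion of Eq. (5), is given below. The paper only states that \(\mu\) is a trainable vector in \(\mathbb {R}^T\); its initialization to 0.5 here is our assumption.

```python
class TwoBranchClassifier(nn.Module):
    """Linear branch (Eq. 4), graph-assisted branch (Eq. 3), and weighted
    fusion (Eq. 5)."""
    def __init__(self, graph_branch: nn.Module, num_labels: int,
                 d_t: int = 768, d_p: int = 256):
        super().__init__()
        self.graph_branch = graph_branch            # GraphAssistedClassifier above
        self.f_p = nn.Linear(d_t, d_p, bias=False)  # f^p
        self.l_w = nn.Parameter(torch.empty(d_p, num_labels))  # L_W
        nn.init.xavier_uniform_(self.l_w)
        self.mu = nn.Parameter(torch.full((num_labels,), 0.5))  # per-label ratio

    def forward(self, h_t, h_g0, adj):
        z_w = self.f_p(h_t) @ self.l_w              # Eq. (4): linear branch
        z_g = self.graph_branch(h_t, h_g0, adj)     # Eq. (3): graph-assisted branch
        z = self.mu * z_g + (1.0 - self.mu) * z_w   # Eq. (5): fusion
        return z, z_w, z_g
```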

3.2 AIIF Training

We adopt a two-stage training scheme for AIIF: (1) training the text encoder; (2) training the classification layer.

Training the Text Encoder. In the first stage, we train the text encoder with supervised signals from the dataset to obtain well-learned text features. Specifically, we add a linear classifier after the text encoder and compare its predictions with the ground-truth labels using a squared hinge loss.

$$\begin{aligned} SH(\boldsymbol{y}, \hat{\boldsymbol{y}})=\sum _{i=1}^T (max(0, 1 - \boldsymbol{y}_i \cdot \hat{\boldsymbol{y}}_i)^2)~ \end{aligned}$$
(6)

where \(\hat{\boldsymbol{y}} \in \mathbb {R}^{T}\) is the prediction of the linear classifier. After training, only the parameters of the text encoder are kept for later use.
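A sketch of the stage-one objective in Eq. (6); we assume the ground-truth labels are encoded as \(\pm 1\) and that the loss is averaged over the batch, neither of which is stated explicitly in the paper.

```python
import torch

def squared_hinge_loss(y: torch.Tensor, y_hat: torch.Tensor) -> torch.Tensor:
    # y: (batch, T) labels in {-1, +1}; y_hat: (batch, T) raw scores of the
    # auxiliary linear classifier. Eq. (6), summed over labels, mean over batch.
    return torch.clamp(1.0 - y * y_hat, min=0.0).pow(2).sum(dim=-1).mean()
```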

Training the Classification Layer. In the second stage, we freeze the parameters of the text encoder. The remaining parts of AIIF are randomly initialized and trained with the binary cross-entropy (BCE) loss. Given the ground-truth label vector \(\boldsymbol{y}\) and a vector of predicted probabilities \(p\), the BCE loss is calculated as

$$\begin{aligned} BCE(p, \boldsymbol{y}) = -\sum _{i=1}^{T} (\boldsymbol{y}_{i} \log p_{i}+\left( 1-\boldsymbol{y}_{i} \right) \log \left( 1-p_{i}\right) ) \end{aligned}$$
(7)

We apply the BCE loss to the linear classifiers, the graph-assisted classifier, and the whole model. The final loss \(L\) is calculated as

$$\begin{aligned} \begin{aligned} L_f&=BCE(sigmoid(z), \boldsymbol{y})\\ L_w&=BCE(sigmoid(z^w), \boldsymbol{y})\\ L_g&=BCE(sigmoid(z^g), \boldsymbol{y})\\ L&=L_f + \alpha L_w + \beta L_g, \end{aligned} \end{aligned}$$
(8)

where \(\alpha \) and \(\beta \) are hyper-parameters used for balancing these losses.
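A sketch of the stage-two objective in Eq. (8). PyTorch's binary_cross_entropy_with_logits applies the sigmoid internally and averages over labels by default instead of summing, which only changes the scale of the loss.

```python
import torch.nn.functional as F

def classification_loss(z, z_w, z_g, y, alpha: float = 1.0, beta: float = 1.0):
    # z, z_w, z_g: (batch, T) logits of the fused, linear, and graph branches;
    # y: (batch, T) binary ground-truth labels as floats.
    l_f = F.binary_cross_entropy_with_logits(z, y)
    l_w = F.binary_cross_entropy_with_logits(z_w, y)
    l_g = F.binary_cross_entropy_with_logits(z_g, y)
    return l_f + alpha * l_w + beta * l_g
```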

For the initialization of the vertex embedding matrix, one option is to use the mean-pooling of the word embeddings of the tokens in the label text. However, some MLTC datasets do not provide label text, or provide it only in abbreviated form, which prevents us from obtaining good initial vertex embeddings. Thus, we initialize the embedding of a category label according to the documents belonging to that category. More specifically, if a set of documents \(\{\boldsymbol{x}_1, \boldsymbol{x}_2, \ldots , \boldsymbol{x}_k\}\) have the ground-truth label \(\boldsymbol{l}\), then the initial vertex embedding of \(\boldsymbol{l}\) is

$$\begin{aligned} \boldsymbol{H^0_g} = \frac{1}{k} \sum _{j=1}^k BERT(\boldsymbol{x}_j)~. \end{aligned}$$
(9)
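A minimal sketch of this initialization, reusing the encode helper from the text-encoder sketch; the function name and the handling of unseen labels (kept at zero) are our assumptions.

```python
import torch

def init_node_features(encode, documents, label_sets, num_labels, d_t: int = 768):
    # documents: list of raw texts; label_sets: list of label-id lists, one per
    # document. Eq. (9): average the BERT features of documents carrying a label.
    sums = torch.zeros(num_labels, d_t)
    counts = torch.zeros(num_labels, 1)
    with torch.no_grad():
        for doc, labels in zip(documents, label_sets):
            h = encode(doc).squeeze(0)
            for l in labels:
                sums[l] += h
                counts[l] += 1
    return sums / counts.clamp(min=1)  # H_g^0, shape (T, d_t)
```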

3.3 Label Graph Construction

Following [4], we construct the label graph according to the co-occurrence patterns between labels within the dataset. Details are as follows.

First, we count the co-occurrence of label pairs \((\boldsymbol{l}_i, \boldsymbol{l}_j)\) in the training set to obtain the label correlation matrix \(\boldsymbol{M} \in \mathbb {R}^{T \times T}\). Then, we calculate the conditional probability \(P(\boldsymbol{l}_j | \boldsymbol{l}_i)\) that \(\boldsymbol{l}_j\) appears when \(\boldsymbol{l}_i\) appears as

$$\begin{aligned} P(\boldsymbol{l}_j | \boldsymbol{l}_i) = M_{ij} / {N_i}, \end{aligned}$$
(10)

where \(N_i\) denotes the appearance frequency of \(\boldsymbol{l}_i\) in the training set.

Then, we apply a threshold \(\tau \) to filter out noisy rare co-occurrences via

$$\begin{aligned} \boldsymbol{A}_{ij} = \left\{ \begin{array}{lr} 0, \ \ \text {if} \ P(\boldsymbol{l}_j | \boldsymbol{l}_i)\ <\ \tau \\ 1, \ \ \text {otherwise} \end{array} \right. ~, \end{aligned}$$
(11)

where \(\boldsymbol{A}\) is the binary adjacency matrix. However, directly applying \(\boldsymbol{A}\) to GCN may cause the over-smoothing problem [12]. To alleviate the issue, we re-weight \(\boldsymbol{A}\) as follows to obtain the final adjacency matrix \(\boldsymbol{A'}\).

$$\begin{aligned} \boldsymbol{A'}_{ij} = \left\{ \begin{array}{lr} p/\sum _{j=1, i\ne j}^T A_{ij},\ \ \text {if}\ i \ne j, \\ 1 - p,\ \ \text {if}\ i=j \end{array} \right. \end{aligned}$$
(12)
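The three steps of Eqs. (10)–(12) can be sketched as follows. We read Eq. (12) literally, i.e., every off-diagonal entry in a row receives the same weight; following [4], one may instead multiply by \(A_{ij}\) so that only retained edges carry weight.

```python
import numpy as np

def build_label_graph(label_sets, num_labels: int, tau: float = 0.5, p: float = 0.2):
    # label_sets: list of label-id lists, one per training document.
    M = np.zeros((num_labels, num_labels))   # co-occurrence counts
    N = np.zeros(num_labels)                 # label frequencies
    for labels in label_sets:
        for i in labels:
            N[i] += 1
            for j in labels:
                if i != j:
                    M[i, j] += 1
    P = M / np.maximum(N[:, None], 1)        # Eq. (10): P(l_j | l_i)
    A = (P >= tau).astype(float)             # Eq. (11): binary adjacency
    A_prime = np.zeros_like(A)
    for i in range(num_labels):
        deg = A[i].sum()
        if deg > 0:
            A_prime[i, :] = p / deg          # Eq. (12), off-diagonal entries
        A_prime[i, i] = 1.0 - p              # Eq. (12), diagonal entries
    return A_prime
```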

4 Experiments

4.1 Datasets and Evaluations

Table 1. Summary of the datasets. N is the number of samples in the training and validation set, M is the size of the testing set, W denotes the average length of documents, and \(\tilde{C}\) means the average number of labels per sample.

We perform experiments on two widely-used MLTC datasets: RCV1 [11] and AAPD. For a fair comparison, we follow the dataset split used in previous work [20]. The statistics of datasets are shown in Table 1.

Following the settings of previous work [16, 20], we apply two metrics for performance evaluation: precision at top k (P@k) and normalized discounted cumulative gain at top k (nDCG@k). Given the ground-truth binary vector \(\boldsymbol{y} \in \{0, 1\}^T\), P@k is defined as follows:

$$\begin{aligned} \begin{aligned} P@k = \frac{1}{k} \sum _{l=1}^k \boldsymbol{y}_{rank(l)} \end{aligned} \end{aligned}$$
(13)

where rank(l) is the index of the l-th highest predicted label. nDCG@k is defined as follows:

$$\begin{aligned} \begin{aligned} DCG@k&= \sum _{l=1}^k \frac{\boldsymbol{y}_{rank(l)}}{\log (l+1)} \\ iDCG@k&= \sum _{l=1}^{min(k, ||\boldsymbol{y}||_0)} \frac{1}{\log (l+1)} \\ nDCG@k&= \frac{DCG@k}{iDCG@k} \end{aligned} \end{aligned}$$
(14)
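Both metrics can be computed per document as sketched below; we assume base-2 logarithms, which the formulas leave unspecified.

```python
import numpy as np

def precision_at_k(scores: np.ndarray, y: np.ndarray, k: int = 5) -> float:
    # Eq. (13): fraction of the top-k predicted labels that are correct.
    topk = np.argsort(-scores)[:k]
    return float(y[topk].sum()) / k

def ndcg_at_k(scores: np.ndarray, y: np.ndarray, k: int = 5) -> float:
    # Eq. (14): DCG over the top-k predictions, normalized by the ideal DCG.
    topk = np.argsort(-scores)[:k]
    dcg = float((y[topk] / np.log2(np.arange(2, k + 2))).sum())
    ideal = min(k, int(y.sum()))
    idcg = float((1.0 / np.log2(np.arange(2, ideal + 2))).sum())
    return dcg / idcg if idcg > 0 else 0.0
```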

4.2 Baselines

We select the following methods as baselines. (1) DXML: Zhang et al. [23] construct a label graph based on label co-occurrence and apply a random-walk-based method to obtain node features. (2) AttentionXML: You et al. [22] adopt a multi-label attention mechanism to perform hierarchical classification. (3) LSAN: Xiao et al. [20] use an attention mechanism to model the relations between document words and labels. (4) LDGN: Ma et al. [16] use GCN to encode a static and a dynamic document-specific label graph for the predictions of each document; it achieves the previous state-of-the-art results on the RCV1 and AAPD datasets. Unless otherwise stated, we report the results of the baselines from their original papers.

4.3 Implementation Details

We apply BERT [5] as the text encoder because BERT obtains strong contextualized text representations and has achieved great success in many NLP tasks. The pre-trained BERT model used in this paper is provided by the transformersFootnote 1 library. We use the \(\text {BERT}_{base}\) checkpoint. We set \(d_t=768\) and \(d_g=d_p=256\). When constructing the label graph, we use \(\tau = 0.5\) and \(p = 0.2\).

For all training stages, we use the Adam optimizer. The initial learning rate is \(5 \times 10^{-4}\), except for fine-tuning BERT in the first training stage, where it is \(5 \times 10^{-5}\). We warm up the learning rate during the first 10% of the training process and apply linear learning rate decay for the remainder. We use early stopping, and the maximum number of training epochs for each stage is 20. We set \(\alpha =1\) and \(\beta =1\). Each document is truncated to a length of 350 and 250 tokens for the RCV1 and AAPD datasets, respectively.

Table 2. Compare AIIF with previous methods. \(^\dag \) indicates that the scores are collected from [20]. Best results are shown in bold.

4.4 Main Results

We compare AIIF with previous methods. The results are shown in Table 2. From the results, we can observe the following:

  1. AIIF significantly outperforms BERT. AIIF outperforms BERT on all metrics on both datasets, with improvements between \(0.70\%\) and \(1.15\%\) on the RCV1 dataset and between \(0.23\%\) and \(1.44\%\) on the AAPD dataset. The advantage of AIIF over BERT is that AIIF captures both intra- and inter-class information, not just intra-class information.

  2. Competitive results on both datasets. Compared with the previous state-of-the-art method LDGN, AIIF outperforms LDGN on all metrics on the RCV1 dataset, with improvements between \(0.27\%\) and \(2.35\%\); on the AAPD dataset, AIIF achieves results close to LDGN, with differences between \(-1.20\%\) and \(0.73\%\) for each metric.

4.5 Ablation Study

Table 3. Comparison of AIIF with its variants on the MLTC datasets. AIIF-L denotes the model consisting of the text encoder and the linear classifiers. AIIF-G denotes the model consisting of the text encoder and the graph-assisted classifier. The best results are shown in bold.

To further analyze the effectiveness of the two-branch architecture, we compare AIIF with two variants: (1) AIIF-L: we remove the graph-assisted classifier and the fusion module from AIIF, and the linear classifiers’ predictions are treated as the final prediction. (2) AIIF-G: we remove the linear classifiers and the fusion module from AIIF, and the graph-assisted classifier’s predictions are treated as the final prediction. AIIF, AIIF-L, and AIIF-G follow the same training procedure.

The results are shown in Table 3. We observe that, in most cases, removing either branch of AIIF causes a performance drop on both datasets. Taking the P@1 score on the AAPD dataset as an example, AIIF-L is inferior to AIIF by \(0.81\%\), and AIIF-G is inferior to AIIF by \(1.27\%\). The performance drop demonstrates the effectiveness of the proposed two-branch classifier.

4.6 Performance on the Tail Labels

As mentioned in Sect. 1, previous work shows that encoding inter-class information achieves promising results on tail labels [19]. We are interested in whether introducing intra-class information further improves results on tail labels. Thus, we evaluate AIIF and its variants with the propensity-scored precision at k (PSP@k), which is calculated as

$$\begin{aligned} PSP@k = \frac{1}{k} \sum _{l=1}^k \frac{\boldsymbol{y}_{rank(l)}}{P_{rank(l)}} \end{aligned}$$
(15)
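where \(P_{rank(l)}\) denotes the propensity score of the label ranked at position l. A per-document sketch, assuming the propensity scores are precomputed:

```python
def psp_at_k(scores, y, propensity, k: int = 5) -> float:
    # Eq. (15): like P@k, but each hit is up-weighted by the inverse propensity
    # of its label, which emphasizes tail labels.
    topk = np.argsort(-scores)[:k]
    return float((y[topk] / propensity[topk]).sum()) / k
```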

Results are shown in Fig. 2. We observe that AIIF outperforms AIIF-G on both datasets, demonstrating that even if a model has captured inter-class information, introducing intra-class information still improves performance on tail labels.

Fig. 2. Performance on tail labels.

4.7 Case Study

Furthermore, we compare the prediction results of AIIF and BERT in Table 4 to analyze why AIIF is superior to BERT.

Table 4. A case study on the RCV1 dataset. Here, we show the top 5 predictions of BERT and AIIF; the correct predictions are colored in red.

For the example shown in Table 4, BERT correctly predicts the categories “government/social,” “expenditure/revenue,” and “welfare, social services,” but misses the categories “economics” and “government finance”. AIIF predicts all the correct categories. We believe this can be attributed to AIIF extracting the relationships between categories after introducing the category graph. In the training set, the categories “economics” and “government finance” co-occur strongly with the three categories correctly predicted by BERT. Correspondingly, they are connected by edges in the category graph. Extracting such inter-class information increases the probability of predicting “economics” and “government finance”. Conversely, the two categories that BERT incorrectly predicted, “health” and “domestic politics”, co-occur weakly with the three categories that BERT correctly predicted. Hence, AIIF excludes these two incorrect predictions.

5 Conclusion

This paper studies the multi-label text classification task. Previous methods focus on either intra-class or inter-class information. We propose a novel two-branch architecture to capture both. Experimental results show that the model capturing both intra-class and inter-class information outperforms those modeling only one of them.