
1 Introduction

Recent work in machine learning and computer vision has demonstrated the advantages of integrating human attention with artificial neural network models: studies show that many machine vision tasks, e.g., image segmentation, image captioning, and object recognition, can benefit from the addition of human visual attention [36].

Visual attention is the ability of biological visual systems to selectively attend to regions or features of a scene that are relevant to a specific task [3]. “Bottom-up” attention (also called exogenous attention) is driven by physical properties of the visual input that are salient and distinguishable, whereas “top-down” attention (also called endogenous attention) generally refers to the mental strategies the visual system adopts to accomplish an intended visual task [44]. Early research on saliency prediction aimed to understand attention triggered by visual features and patterns, so “bottom-up” attention was its main focus [3]. More recent work, empowered by interdisciplinary efforts, has begun to study both “bottom-up” and “top-down” attention, and the terms saliency prediction and visual attention prediction are therefore used interchangeably [53]. In this paper, we use the term saliency prediction to mean the prediction of how humans allocate visual attention when viewing 2D images, including both “bottom-up” and “top-down” attention. Human visual attention distribution is usually represented as a 2D heatmap. Note that the saliency prediction studied in this paper differs from a neural network’s own saliency/attention, which can be visualized through class activation mapping (CAM) [63] and other methods [15, 48, 51]. With the establishment of several benchmark datasets, data-driven approaches have achieved major advances in saliency prediction (see the reviews in [2] and [60]). However, most of this work focuses on natural scenes, and more needs to be done in the medical domain. We therefore study saliency prediction for the examination of chest X-ray (CXR) images, one of the most common radiology tasks worldwide.

CXR imaging is commonly used to diagnose cardiac and/or respiratory abnormalities; a single shot can reveal multiple conditions, e.g., COVID-19, pneumonia, and heart enlargement [6]. Multiple public CXR datasets exist [20, 61]. However, creating large, comprehensive medical datasets is labour-intensive and requires significant medical resources, which are usually scarce [9]. Consequently, medical datasets are rarely as abundant as those in non-medical fields, and machine learning approaches applied to medical datasets need to address data scarcity. In this paper, we adopt multi-task learning as a solution.

Multi-task learning is known for its inductive transfer characteristics, which can drive strong representation learning and improve the generalization of each component task [8]. Multi-task learning methods can therefore partially alleviate some major shortcomings of deep learning, namely high demands for data and heavy computational loads [11]. However, challenges remain in applying multi-task learning successfully, including the selection of component tasks, the network architecture, and the optimization of the training scheme [11, 62]. This paper investigates a suitable configuration of a multi-task learning model that tackles visual saliency prediction and image classification simultaneously.

The main contributions of this paper are: 1) a new deep convolutional neural network (DCNN) architecture based on UNet [47] for CXR image saliency prediction and classification, and 2) an optimized multi-task learning scheme that mitigates overfitting. Our method aims to outperform state-of-the-art networks dedicated to either saliency prediction or image classification.

2 Background

2.1 Saliency Prediction with Deep Learning

DCNNs are the leading machine learning method for saliency prediction [22, 30, 31, 43], and transfer learning with pre-trained networks has been observed to boost saliency prediction performance [31, 41, 42]. Most DCNN approaches target natural-scene saliency prediction, and so far only a few have studied saliency prediction for medical images. In [5], a generative adversarial network is used to predict an expert sonographer’s saliency when performing standard fetal head plane detection on ultrasound (US) images. However, saliency prediction serves there as a secondary task to assist the primary detection task, and its performance fails to outperform benchmark prediction methods on several key metrics. Similarly, in a proof-of-concept study [25], gaze data are used for an auxiliary task in CXR image classification, but saliency prediction performance is not reported.

2.2 CXR Image Classification with Deep Learning

Public CXR datasets have enabled data-driven approaches to automatic image analysis and diagnosis [33, 50], and advances in standardized image classification networks, e.g., ResNet [18], DenseNet [19], and EfficientNet [55], have facilitated CXR image classification. Yet CXR image classification remains challenging, as CXR images are noisy and may contain subtle features that are difficult to recognize even for experts [6, 28].

3 Multi-task Learning Method

As stated in Sect. 1, component task selection, network architecture design, and the training scheme are key factors in multi-task learning. We pair saliency prediction with classification because attention patterns are task specific [26]: radiologists are likely to exhibit distinguishable visual behaviors when different patient conditions are shown on CXR images [38]. This section introduces our multi-task UNet (MT-UNet) architecture and derives an improved multi-task training scheme for saliency prediction and image classification.

Fig. 1.

MT-UNet architecture. The solid blocks represent 3D tensors in \(\mathbf {R}^{F\times H\times W}\), where F, H, and W denote the feature (channel), height, and width dimensions, respectively. The solid circles represent 1D tensors, and arrows denote operations on the tensors. Numbers above some of the solid blocks give the number of features in the tensors.

3.1 Multi-task UNet

Figure 1 shows the architecture of the proposed MT-UNet. The network takes a CXR image \(\boldsymbol{x}\in \mathbf {R}^{1\times H\times W}\), where H and W are the image dimensions, as input and produces two outputs: the predicted saliency \(\boldsymbol{y}_s\in \mathbf {R}^{1\times H\times W}\) and the predicted classification \(\boldsymbol{y}_c\in \mathbf {R}^{C}\), where C is the number of classes. Because the ground truth for \(\boldsymbol{y}_s\) is a human visual attention distribution, represented as a 2D matrix whose elements are non-negative and sum to 1, \(\boldsymbol{y}_s\) is normalized by a Softmax over all \(H\times W\) positions before being output from MT-UNet. A Softmax is also applied to \(\boldsymbol{y}_c\) so that the classification outcome can be interpreted as class probabilities. For simplicity of notation, batch dimensions are omitted.
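The two output normalizations can be illustrated with a minimal pure-Python sketch. It assumes nested lists in place of framework tensors and uses hypothetical logit values. A Softmax over all \(H\times W\) positions makes \(\boldsymbol{y}_s\) non-negative and sum to 1, like the ground-truth attention map, and a Softmax over the C class logits turns \(\boldsymbol{y}_c\) into class probabilities.

```python
import math

def softmax(values):
    """Numerically stable Softmax over a flat list of logits."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def spatial_softmax(logit_map):
    """Softmax over every position of an H x W map (list of rows), so the
    result is non-negative and sums to 1 like a ground-truth attention map."""
    h, w = len(logit_map), len(logit_map[0])
    flat = softmax([v for row in logit_map for v in row])
    return [flat[i * w:(i + 1) * w] for i in range(h)]

# Hypothetical logits: a 2 x 3 saliency map and 3 classes
# (normal, enlarged heart, pneumonia).
y_s = spatial_softmax([[0.1, 2.0, -1.0], [0.5, 0.0, 1.5]])
y_c = softmax([1.2, -0.3, 0.4])
```

Normalizing over the whole map, rather than per row or per channel, is what makes the prediction directly comparable to a probability distribution of gaze.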

The proposed MT-UNet is derived from the standard UNet architecture [47]. As a well-known image-to-image deep learning model, UNet has been adopted for various tasks. For example, UNet has been extended with additional structures for visual scene understanding [21]; features from its bottleneck (the middle of the UNet) have been extracted for image classification [25]; and by combining UNet with Pyramid Net [35], features at different depths have been aggregated for enhanced segmentation [40]. Moreover, the encoder-decoder structure of UNet has been used for multi-task learning, where the encoder learns representative features and designated decoders or classification heads perform image reconstruction, segmentation, and/or classification [1, 64]. In our design, we add classification heads (shaded in light green in Fig. 1) not only to the bottleneck but also to the final part of the UNet architecture. This classification-specific structure aggregates middle- and higher-level features, exploiting features learnt at different depths. The classification heads apply global average pooling to the (batched, 4D) feature tensors, concatenate the pooled features, and pass them through two linear transforms (dense layers) with dropout (rate = \(25\%\)) in between to produce the classification outcome. MT-UNet follows the hard parameter sharing paradigm of multi-task learning, in which different tasks share the same trainable parameters before branching out into task-specific parameters [58]. More trainable parameters in a task-specific structure may improve performance on that task, at the cost of additional parameters and heavier computation [11, 58]. Because we wish to avoid heavy structures with many task-specific parameters, task-specific structures are kept to a minimum. In Fig. 1, yellow and green shades denote network structures dedicated to saliency prediction and classification, respectively.
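The classification-head computation (global average pooling, concatenation, and two dense layers with 25% dropout in between) can be sketched in pure Python as below. This is an illustration under assumptions, not the authors' implementation: tensor shapes and weights are hypothetical, only one bottleneck and one top-layer feature map are used, and no activation is placed between the dense layers because the text does not specify one.

```python
import random

def global_avg_pool(tensor):
    """F x H x W nested lists -> list of F per-channel means."""
    return [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in tensor]

def linear(x, weight, bias):
    """Dense layer; weight is a list of rows with shape out x in."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]

def dropout(x, rate, training, rng):
    """Inverted dropout: identity at inference, random zeroing when training."""
    if not training:
        return list(x)
    keep = 1.0 - rate
    return [v / keep if rng.random() < keep else 0.0 for v in x]

def classification_head(bottleneck, top, params, training=False, rng=None):
    # Pool the bottleneck and top-layer features, then concatenate them.
    feats = global_avg_pool(bottleneck) + global_avg_pool(top)
    hidden = linear(feats, params["w1"], params["b1"])
    hidden = dropout(hidden, 0.25, training, rng or random.Random(0))
    # Class logits; the Softmax is applied afterwards, as for y_c.
    return linear(hidden, params["w2"], params["b2"])

# Hypothetical inputs: 2-channel bottleneck, 1-channel top layer, 3 classes.
bottleneck = [[[1, 1], [1, 1]], [[2, 4], [6, 8]]]
top = [[[0, 2], [2, 0]]]
params = {
    "w1": [[1, 0, 0], [0, 1, 0], [0, 0, 1]], "b1": [0.0, 0.0, 0.0],
    "w2": [[1, 1, 1], [0, 1, 0], [1, 0, -1]], "b2": [0.5, 0.0, 0.0],
}
logits = classification_head(bottleneck, top, params)
```

Because pooling removes the spatial dimensions, features from layers of different resolution can be concatenated directly, which is what lets the head combine middle- and higher-level features cheaply.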

3.2 Multi-task Training Scheme

Balancing the losses between tasks in a multi-task training process has a direct impact on the training outcome [58]. Among existing multi-task training schemes [10, 16, 27, 49], we adopt the uncertainty-based balancing scheme [27] with the modification used in [34, 65]. The loss function is therefore:

$$\begin{aligned} \mathcal {\boldsymbol{L}} = \frac{1}{\sigma _s^2}L_s+\frac{1}{\sigma _c^2}L_c+\ln (\sigma _s+1)+\ln (\sigma _c+1) \end{aligned}$$
(1)

where \(L_s\) and \(L_c\) are the loss values for \(\boldsymbol{y}_s\) and \(\boldsymbol{y}_c\), respectively; \(\sigma _s>0\) and \(\sigma _c>0\) are trainable scalars, initialized to 1, that estimate the uncertainty of \(L_s\) and \(L_c\), respectively; and \(\ln (\sigma _s+1)\) and \(\ln (\sigma _c+1)\) are regularization terms that keep \(\sigma _s\) and \(\sigma _c\) from growing arbitrarily, which would otherwise drive the weighted losses toward zero. Eq. 1 lets the \(\sigma \) values dynamically weight losses of different magnitudes during training, and a loss with low uncertainty (small \(\sigma \)) is prioritized in the training process. Moreover, since \(L_s\ge 0\), \(L_c\ge 0\), and \(\ln (\sigma +1)>0\) for \(\sigma >0\), the total loss satisfies \(\mathcal {\boldsymbol{L}}>0\). Given \(\boldsymbol{y}_s\) and \(\boldsymbol{y}_c\) with their ground truth \(\bar{\boldsymbol{y}}_s\) and \(\bar{\boldsymbol{y}}_c\), respectively, the loss functions are:
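Eq. 1 can be written as a small function. This is a sketch only: \(\sigma _s\) and \(\sigma _c\) are passed as plain positive floats, and how they are kept positive while being trained (e.g., through a softplus or exponential parameterization) is an assumption the text does not settle.

```python
import math

def uncertainty_weighted_loss(loss_s, loss_c, sigma_s, sigma_c):
    """Eq. (1): scale each task loss by 1/sigma^2 and add ln(sigma + 1).

    sigma_s and sigma_c are plain positive floats here; how they are kept
    positive during training is an assumption left to the implementation.
    """
    if sigma_s <= 0 or sigma_c <= 0:
        raise ValueError("sigma values must be positive")
    return (loss_s / sigma_s ** 2 + loss_c / sigma_c ** 2
            + math.log(sigma_s + 1) + math.log(sigma_c + 1))

# At initialization (sigma_s = sigma_c = 1) this is L_s + L_c + 2 ln 2.
print(uncertainty_weighted_loss(0.4, 0.6, 1.0, 1.0))
```

Using \(\ln (\sigma +1)\) rather than \(\ln \sigma \) keeps every regularization term positive, which is why the total loss stays above zero.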

$$\begin{aligned} L_s = H(\bar{\boldsymbol{y}}_s, {\boldsymbol{y}}_s)-H(\bar{\boldsymbol{y}}_s), \end{aligned}$$
(2)
$$\begin{aligned} L_c = H(\bar{\boldsymbol{y}}_c, {\boldsymbol{y}}_c) \quad \quad \quad \quad \end{aligned}$$
(3)

where \(H(Q,R)=-\sum _{i=1}^{n}Q_i\ln (R_i)\) is the cross entropy of two discrete distributions Q and R, each with n elements, and \(H(Q)=H(Q,Q)\) is the entropy (self cross entropy) of the discrete distribution Q. Thus \(L_s\) is the Kullback-Leibler divergence (KLD) loss and \(L_c\) is the cross-entropy loss. From Eqs. 2 and 3, only the cross-entropy terms \(H(\cdot , \cdot )\) generate gradients when the network parameters are updated, because \(-H(\bar{\boldsymbol{y}}_s)\) in \(L_s\) depends only on the ground truth and is therefore a constant with zero gradient. We therefore extend the method in [27] and use \(\frac{1}{\sigma ^2}\) to scale the KLD loss \(L_s\) in the same way as the cross-entropy loss \(L_c\).
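The decomposition in Eqs. 2 and 3 can be checked numerically: cross entropy minus entropy equals the direct KL-divergence sum \(\sum _i Q_i\ln (Q_i/R_i)\). The distributions below are hypothetical, and the sketch uses the convention that entries with \(Q_i=0\) contribute nothing.

```python
import math

def cross_entropy(q, r):
    """H(Q, R) = -sum_i Q_i ln R_i; entries with Q_i = 0 contribute 0."""
    return -sum(qi * math.log(ri) for qi, ri in zip(q, r) if qi > 0)

def entropy(q):
    """H(Q) = H(Q, Q)."""
    return cross_entropy(q, q)

def kld_loss(target, pred):
    """Eq. (2): L_s = H(target, pred) - H(target) = KL(target || pred)."""
    return cross_entropy(target, pred) - entropy(target)

def ce_loss(target, pred):
    """Eq. (3): L_c = H(target, pred)."""
    return cross_entropy(target, pred)

gt = [0.5, 0.25, 0.25, 0.0]     # hypothetical ground-truth attention (flattened)
pred = [0.4, 0.3, 0.2, 0.1]     # hypothetical prediction
direct_kl = sum(g * math.log(g / p) for g, p in zip(gt, pred) if g > 0)
```

Since `entropy(gt)` does not involve `pred`, only the cross-entropy part of `kld_loss` changes as the prediction changes.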

Although the training scheme in Eq. 1 has been applied successfully in many settings, overfitting can still jeopardize the training of multi-task networks, especially on small datasets [59]. Multiple factors can cause overfitting, among which the learning rate \(r>0\) has the most significant impact [32]. More generally, r strongly influences the training outcome [52], making it one of the most important hyper-parameters of a training process. When training MT-UNet, r is moderated by several factors. The first is the optimizer: many optimizers, e.g., Adam [29] and RMSProp [57], deploy the momentum mechanism or its variants, which adaptively adjust the effective learning rate \(r_e\) during training. The second is the learning rate scheduler, which is often used for more efficient training. Its influence on r can be adaptive, e.g., reduce learning rate on plateau (RLRP), or follow a predefined schedule, e.g., cosine annealing with warm restarts [37]. The third factor follows from Eq. 1: an uncertainty estimator \(\sigma \) for a loss L also acts as a learning rate adaptor for L. More specifically, given a loss value L with learning rate r, the effective learning rate for parameters updated with the scaled loss \(\frac{L}{\sigma ^2}\) is \(\frac{r}{\sigma ^2}\).
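The effective-learning-rate argument can be verified with one step of plain gradient descent on a toy quadratic loss (a hypothetical example). Under plain SGD, a step on \(\frac{L}{\sigma ^2}\) with learning rate r is identical to a step on L with learning rate \(\frac{r}{\sigma ^2}\). With adaptive optimizers such as Adam, which normalize gradients by their running magnitude, the relation is not exact, so this sketch covers only the SGD case.

```python
def sgd_step(theta, grad, lr):
    """One plain gradient-descent update."""
    return theta - lr * grad

def grad_toy_loss(theta, target=3.0):
    """Gradient of the hypothetical toy loss L(theta) = (theta - target)^2."""
    return 2.0 * (theta - target)

theta, r, sigma = 0.0, 0.1, 0.5
# Step on the scaled loss L / sigma^2 with the base learning rate r ...
scaled = sgd_step(theta, grad_toy_loss(theta) / sigma ** 2, r)
# ... matches a step on L with the effective learning rate r / sigma^2.
effective = sgd_step(theta, grad_toy_loss(theta), r / sigma ** 2)
```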

Decreasing r upon overfitting can alleviate its effects [12, 52], but Eq. 1 increases the learning rate upon overfitting, further worsening the training process. This happens because the training loss decreases when overfitting occurs, which also reduces its variance. Consequently, \(\sigma \) decreases, which increases the effective learning rate and creates a vicious circle of overfitting. This phenomenon can be observed in Fig. 2, which presents the changes in losses and \(\sigma \) values during a training process following Eq. 1. In Fig. 2(a), after an initial decrease in both the training and validation losses, the training loss starts to decrease at an accelerating rate at epoch 40 while the validation loss starts to increase, which is the vicious circle of overfitting. An RLRP scheduler can halt the vicious circle by resetting the model parameters to an earlier epoch and reducing r. Yet, even with a reduced r, the vicious circle of overfitting can re-emerge in later epochs. A mathematical proof of this vicious circle of overfitting is presented in Appendix A.
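The direction of the effect can be illustrated, in a simplified way that is not the proof in Appendix A, by treating the task loss L as a fixed number and finding the \(\sigma \) that minimizes its part of Eq. 1, \(\frac{L}{\sigma ^2}+\ln (\sigma +1)\). Setting the derivative to zero gives \(\sigma ^3/(\sigma +1)=2L\), so the optimal \(\sigma \) shrinks as L shrinks and the effective learning rate \(\frac{r}{\sigma ^2}\) grows. The base learning rate below is hypothetical.

```python
def optimal_sigma(loss, tol=1e-12):
    """Minimizer over sigma > 0 of loss / sigma^2 + ln(sigma + 1), for fixed loss > 0.

    The derivative -2*loss/sigma^3 + 1/(sigma + 1) vanishes where
    sigma^3 / (sigma + 1) = 2 * loss; the left side increases with sigma,
    so bisection finds the unique root.
    """
    if loss <= 0:
        raise ValueError("loss must be positive")

    def g(s):
        return s ** 3 / (s + 1) - 2 * loss

    lo, hi = 0.0, 1.0
    while g(hi) <= 0:          # expand the bracket until the root is inside
        hi *= 2
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if g(mid) > 0:
            hi = mid
        else:
            lo = mid
    return 0.5 * (lo + hi)

r = 1e-3  # hypothetical base learning rate
for loss in (1.0, 0.5, 0.1, 0.01):
    s = optimal_sigma(loss)
    print(f"L={loss:<5} sigma*={s:.4f} effective lr={r / s ** 2:.2e}")
```

As the training loss falls, the printed \(\sigma \) falls and the effective learning rate rises, which is the feedback loop described above.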

Fig. 2.

Training process visualization with Eq. 1

To alleviate overfitting, we propose the use of the following equations to replace Eq. 1:

$$\begin{aligned} \mathcal {\boldsymbol{L}} = \frac{1}{\sigma _s^2}L_s+L_c+\ln (\sigma _s+1), \end{aligned}$$
(4)
$$\begin{aligned} \mathcal {\boldsymbol{L}} = L_s+\frac{1}{\sigma _c^2}L_c+\ln (\sigma _c+1). \end{aligned}$$
(5)

Eqs. 4 and 5 fix the uncertainty term of one loss in Eq. 1 to 1 (the resulting constant \(\ln 2\) is dropped because it has no gradient), which reduces the flexibility to change the effective learning rate. With the uncertainty term fixed for one component loss, Eqs. 4 and 5 are able to alleviate overfitting and stabilize the training process. Note that Eqs. 4 and 5 are not interchangeable: depending on the dataset and the training process, overfitting can occur with different severity in each component task, so both equations need to be tested to determine which performs better. In this study, the training process with Eq. 5 achieves the best performance. An ablation study of this method is presented in Sect. 5.
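A sketch of Eqs. 4 and 5 as one function, where the hypothetical argument `weighted_task` selects which loss keeps a trainable \(\sigma \). The gradient contribution of the unweighted loss is no longer rescaled, so under plain SGD its effective learning rate stays at the base value r, and only the weighted loss has its effective learning rate rescaled by \(\frac{1}{\sigma ^2}\).

```python
import math

def single_uncertainty_loss(loss_s, loss_c, sigma, weighted_task):
    """Eqs. (4) and (5): only one task keeps a trainable uncertainty sigma.

    weighted_task="s" gives Eq. (4) (sigma_c fixed to 1);
    weighted_task="c" gives Eq. (5) (sigma_s fixed to 1).
    """
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    if weighted_task == "s":
        return loss_s / sigma ** 2 + loss_c + math.log(sigma + 1)
    if weighted_task == "c":
        return loss_s + loss_c / sigma ** 2 + math.log(sigma + 1)
    raise ValueError("weighted_task must be 's' or 'c'")
```

Usage for the scheme reported best in this study (Eq. 5): `single_uncertainty_loss(loss_s, loss_c, sigma_c, "c")`.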

4 Dataset and Evaluation Methods

We use the “chest X-ray dataset with eye-tracking and report dictation” [25], shared via PhysioNet [39], in this study. The dataset was derived from the MIMIC-CXR dataset [23, 24] with additional gaze tracking and dictation from an expert radiologist. It contains 1083 CXR images, each accompanied by tracked gaze data; a diagnostic label (normal, pneumonia, or enlarged heart); segmentations of the lungs, mediastinum, and aortic knob; and the radiologist’s dictation audio. The CXR images come in various resolutions, e.g., \(3056\times 2044\), and we down-sample and/or pad each image to \(640\times 416\). Gaze data were collected with a GP3 gaze tracker by Gazepoint (Vancouver, Canada), which has an accuracy of around 1\(^\circ \) of visual angle and a sampling rate of 60 Hz [66].
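The text does not spell out the exact resize-and-pad procedure, so the sketch below assumes one plausible reading: aspect-ratio-preserving down-sampling (never up-sampling) followed by symmetric zero padding to \(640\times 416\), with source and target shapes given in the same (rows, cols) order. It computes only the geometry; the actual resampling would be done with an image library.

```python
import math
from fractions import Fraction

def resize_and_pad_geometry(src_shape, dst_shape=(640, 416)):
    """Return (resized_shape, per-axis (pad_before, pad_after)).

    Assumption: scale by the largest factor <= 1 that fits both axes,
    then pad symmetrically (extra pixel after) up to dst_shape.
    """
    scale = min(Fraction(d, s) for s, d in zip(src_shape, dst_shape))
    scale = min(scale, Fraction(1))  # assumption: never up-sample
    resized = tuple(math.floor(s * scale) for s in src_shape)
    pads = tuple(((d - n) // 2, d - n - (d - n) // 2)
                 for n, d in zip(resized, dst_shape))
    return resized, pads

resized, pads = resize_and_pad_geometry((3056, 2044))
```

Exact fractions avoid off-by-one sizes from floating-point rounding when the scale factor makes one axis land exactly on the target.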

Several metrics have been used to evaluate saliency prediction performance, and they can be classified into location-based and distribution-based metrics [4]. Because of the limited accuracy of the GP3 gaze tracker, location-based metrics are not suited to this study. We therefore follow the suggestions in [4] and use KLD for performance evaluation, with histogram similarity (HS) and Pearson’s correlation coefficient (PCC) included for reference. To evaluate classification performance, we use the area under the ROC curve (AUC) for multi-class classification [14, 17] and the classification accuracy (ACC). We also report the AUC for each class, normal, enlarged heart, and pneumonia, denoted AUC-Y1, AUC-Y2, and AUC-Y3, respectively. In this paper, all metric values are presented as the median followed by the standard deviation after the ± sign. An up-pointing arrow \(\uparrow \) indicates that greater values reflect better performance, and a down-pointing arrow the opposite. The best results are shown in bold.
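The metrics can be sketched as follows, using formulations commonly used in the saliency benchmarking literature (cf. [4]) on flattened maps. The epsilon constant is an assumption. Per-class AUC (AUC-Y1 to AUC-Y3) is computed one-vs-rest via the Mann-Whitney statistic; how the per-class values are combined into the multi-class AUC of [14, 17] is not reproduced here.

```python
import math

EPS = 1e-7  # assumption: small constant to avoid log(0) and division by zero

def _normalize(x):
    total = sum(x)
    return [v / total for v in x]

def kld(gt, pred):
    """KL divergence of the ground-truth map from the prediction (lower is better)."""
    g, p = _normalize(gt), _normalize(pred)
    return sum(gi * math.log(EPS + gi / (EPS + pi)) for gi, pi in zip(g, p))

def histogram_similarity(gt, pred):
    """HS: sum of element-wise minima of two normalized maps (1 = identical)."""
    return sum(min(a, b) for a, b in zip(_normalize(gt), _normalize(pred)))

def pcc(gt, pred):
    """Pearson's correlation coefficient between two maps."""
    n = len(gt)
    mg, mp = sum(gt) / n, sum(pred) / n
    cov = sum((a - mg) * (b - mp) for a, b in zip(gt, pred))
    sg = math.sqrt(sum((a - mg) ** 2 for a in gt))
    sp = math.sqrt(sum((b - mp) ** 2 for b in pred))
    return cov / (sg * sp)

def one_vs_rest_auc(scores, labels, positive):
    """AUC for one class: probability a positive outscores a negative (ties = 1/2)."""
    pos = [s for s, y in zip(scores, labels) if y == positive]
    neg = [s for s, y in zip(scores, labels) if y != positive]
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0 for p in pos for q in neg)
    return wins / (len(pos) * len(neg))
```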

5 Experiments and Results

5.1 Benchmark Comparison

In this subsection, we compare the performance of MT-UNet with benchmark networks for CXR image classification and saliency prediction. Detailed training settings are presented in Appendix B.

For CXR image classification, the benchmark networks are chosen from the top-performing networks for CXR image classification examined in [13], namely ResNet50 [18] and Inception-ResNet v2 (abbreviated as IRNetV2 in this paper) [54]. Following [25], we also include a state-of-the-art general-purpose classification network, EfficientNetV2-S (abbreviated as EffNetV2-S) [56]. For completeness, classification using a standard UNet with an additional classification head (denoted UNetC) is included. The results in Table 1 show that MT-UNet outperforms the other classification networks.

For CXR image saliency prediction, MT-UNet is compared with three state-of-the-art saliency prediction models: SimpleNet [46], MSINet [30], and VGGSSM [7]. Saliency prediction using a standard UNet (denoted UNetS) is also included for reference. Table 2 shows the results, where MT-UNet outperforms the other models. Visual comparisons of saliency prediction results are presented in Table 4 in Appendix C.

Table 1. Performance comparison between classification models.
Table 2. Performance comparison between saliency prediction models.

5.2 Ablation Study

To validate the modified multi-task learning scheme, we perform an ablation study. The multi-task learning schemes following Eqs. 1, 4, and 5 are compared and denoted MTLS1, MTLS2, and MTLS3, respectively. Note that the best-performing scheme, MTLS3, is used for the benchmark comparison in Sect. 5.1. Figure 3 shows the training processes for MTLS2 and MTLS3. Figs. 2 and 3 show that overfitting occurs for both MTLS1 and MTLS2 but is reduced in MTLS3. The training processes shown in Figs. 2 and 3 use optimized hyper-parameters. The resulting performances are compared in Table 3, where MTLS3 outperforms the other learning schemes in both classification and saliency prediction.

To validate the effect of classification heads that aggregate features from different depths, we create ablated versions of MT-UNet that use features from either the bottleneck or the top layer for classification, denoted MT-UNetB and MT-UNetT, respectively. The results, presented in Table 3, show that MT-UNet generally performs better than MT-UNetT and MT-UNetB.

Table 3. Ablation study performance comparison.
Fig. 3.

Multi-task learning schemes comparison

6 Discussion

In this paper, we build the MT-UNet model and propose an optimized multi-task learning scheme for saliency prediction and disease classification with CXR images. While a multi-task learning model has the potential to enhance the performance of all component tasks, a proper training scheme is one of the key factors in realizing that potential. As shown in Table 3, MT-UNet trained with the standard multi-task learning scheme may only marginally outperform existing models for saliency prediction or image classification.

Several directions of future work could improve this study. The first is expanding gaze-tracking datasets for medical images. So far, only 1083 CXR images with radiologists’ gaze behavior are publicly available, which limits extensive studies of gaze-tracking-assisted machine learning methods in the medical field. In addition, more dedicated studies of multi-task learning methods, especially for small datasets, could benefit medical machine learning tasks. Overfitting and data scarcity are lingering challenges encountered by many studies, and a better multi-task learning method may handle them more easily.