0154

Avoiding shortcut-learning by mutual information minimization in deep learning-based MR image processing
Louisa Fay1,2, Bin Yang2, Sergios Gatidis1,3, and Thomas Kuestner1
1Medical Image and Data Analysis (MIDAS.lab), Department of Diagnostic and Interventional Radiology, University Hospital of Tuebingen, Tuebingen, Germany, 2Institute of Signal Processing and System Theory, University of Stuttgart, Stuttgart, Germany, 3Max Planck Institute for Intelligent Systems, Tuebingen, Germany

Synopsis

Keywords: Machine Learning/Artificial Intelligence, Data Analysis

Deep Learning methods can detect patterns in data such as MR images but are incapable of determining causal relationships. However, causal understanding is crucial in medical applications, since the presence of confounders (e.g. scan conditions) obscures the causal relationship and creates spurious-correlations. State-of-the-art models rely purely on correlated patterns, which can result in wrong conclusions or diagnoses when the spurious-correlations change (e.g. with a new scanner). We propose a deep learning framework that is robust in the presence of spurious-correlations: it minimizes the mutual information between two parts of the learned MR image features and thereby improves performance under distribution shifts.

Introduction

The promising results of deep learning (DL) in the medical field have the potential to fundamentally transform and support clinical workflows1. DL methods are able to detect correlations in data but are incapable of determining their causal meaning2. Relying purely on learned correlated patterns can lead DL models to wrong or biased conclusions or diagnoses. Furthermore, in new environments with distribution shifts, these models can fail to predict outputs correctly, since they do not capture the real causal relationship between input and output3,4.

Especially in MRI, many factors such as scanner, acquisition site, scan conditions and patient compliance influence the perceived image5. Consequently, DL models trained on these data tend to learn shortcuts and spurious-correlations instead of task-specific features (Fig.1)6. For example, suppose a hospital scans mainly male patients on machine A, while female patients are primarily scanned on machine B. Both machines are from the same vendor, but machine A is older and produces noisier images. For a DL model, it is easier to identify the gender from the spurious-correlation between gender and noise level than to learn the desired causal relationship between gender and anatomic features.
Since many medical databases contain confounding factors, it is important to improve model robustness so that predictions remain correct in every environment7.

To address this challenge, we propose a framework that predicts the desired outcome while simultaneously reducing the influence of spuriously-correlated factors. The proposed method encodes MR images into a feature vector (FV) and splits it into two parts, which predict the primary task (PT) and the perceived spuriously-correlated factor (SC), respectively. The aim is to reduce the mutual information (MI) between both FV parts to achieve independence between PT and SC features. The MI is estimated with mutual information neural estimation (MINE)8. The proposed approach is investigated on ~8300 brain MRIs from the UK Biobank (UKB)9 and the German National Cohort (NAKO)10.

Methods

Model: Our proposed model, called Mutual Information Minimization model (MIM model) (Fig.2B), is built on a feature-encoder that encodes 2D MR images with four convolutional layers followed by four fully-connected layers. The resulting FV is split into two parts of equal length. The upper part XPT predicts the primary task (PT), while the lower part XSC predicts the spuriously-correlated factor (SC).
Since the aim is to obtain independence between XPT and XSC, we propose to minimize their MI, which is estimated by MINE8. The estimated MI is added as a penalty term to the loss (here: mean cross-entropy), so that MI is minimized while the PT and SC tasks are solved. Training alternates between the two networks: the feature-encoder is updated on one batch, followed by two batch updates of the MINE model.
To demonstrate the impact of MI penalization, a baseline model without MI penalty is trained (Fig.2A).
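A minimal sketch of the quantities involved, in plain Python for illustration only. It assumes the Donsker-Varadhan form of the bound used by MINE8, with samples from the product of marginals obtained by shuffling XSC within a batch. The statistics network `stat_net`, the weight `mi_weight` and the summing of the two cross-entropies are hypothetical choices not specified in the abstract. Gradient updates are omitted: in MINE the statistics network is trained to maximize this bound, whereas the feature-encoder would minimize it as a penalty.

```python
import math
import random
from typing import Callable, List, Sequence

Vector = Sequence[float]
# Hypothetical statistics network T(x_pt, x_sc) -> scalar; in MINE it is a trained neural network.
StatNet = Callable[[Vector, Vector], float]


def dv_mi_estimate(stat_net: StatNet, x_pt: List[Vector], x_sc: List[Vector],
                   rng: random.Random) -> float:
    """Donsker-Varadhan lower bound on I(X_PT; X_SC), the quantity MINE maximizes.

    Rows with the same index form joint samples; shuffling x_sc within the batch
    breaks the pairing and yields samples from the product of the marginals.
    """
    joint = [stat_net(a, b) for a, b in zip(x_pt, x_sc)]
    shuffled = list(x_sc)
    rng.shuffle(shuffled)
    marginal = [stat_net(a, b) for a, b in zip(x_pt, shuffled)]
    # log(mean(exp(t))) computed stably by factoring out the maximum
    m = max(marginal)
    log_mean_exp = m + math.log(sum(math.exp(t - m) for t in marginal) / len(marginal))
    return sum(joint) / len(joint) - log_mean_exp


def encoder_loss(ce_pt: float, ce_sc: float, mi_estimate: float, mi_weight: float) -> float:
    """Cross-entropies of both heads plus the weighted MI penalty.

    mi_weight is a hypothetical hyperparameter; mi_weight = 0 gives the baseline (Fig.2A).
    """
    return ce_pt + ce_sc + mi_weight * mi_estimate


def update_schedule(n_encoder_steps: int, mine_steps: int = 2) -> List[str]:
    """Alternating schedule: one encoder batch update, then `mine_steps` MINE updates."""
    schedule: List[str] = []
    for _ in range(n_encoder_steps):
        schedule.append("encoder")
        schedule.extend(["mine"] * mine_steps)
    return schedule


if __name__ == "__main__":
    rng = random.Random(0)
    feats = [[1.0], [-1.0], [1.0], [-1.0]]
    toy_critic = lambda a, b: a[0] * b[0]  # toy critic that scores matching pairs highly
    mi = dv_mi_estimate(toy_critic, feats, feats, rng)
    print(mi, encoder_loss(0.69, 0.69, mi, mi_weight=1.0), update_schedule(1))
```

Shuffling within the batch is the standard way MINE draws marginal samples without a second data pass; it works best with reasonably large batches such as the one used here.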

Settings: All models are trained with 5-fold cross-validation using the Adam optimizer11 (learning rate 10⁻⁴, batch size 420).
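A hypothetical sketch of these settings together with a 5-fold split. If the 2D input images are slices of the 3D volumes, it is an assumption (not stated in the abstract) that folds should be formed at the subject level, so that slices of one subject never appear in both training and validation data.

```python
import random
from typing import List, Sequence, Tuple

# Training settings as stated in the abstract
CONFIG = {"optimizer": "Adam", "learning_rate": 1e-4, "batch_size": 420, "n_folds": 5}


def subject_level_folds(subject_ids: Sequence[str], n_folds: int,
                        seed: int = 0) -> List[Tuple[List[str], List[str]]]:
    """Return (train_ids, val_ids) pairs; every subject is validated exactly once."""
    ids = sorted(set(subject_ids))       # de-duplicate, e.g. one entry per slice
    random.Random(seed).shuffle(ids)     # reproducible shuffle
    folds = [ids[k::n_folds] for k in range(n_folds)]
    splits = []
    for k in range(n_folds):
        train = [s for j, fold in enumerate(folds) if j != k for s in fold]
        splits.append((train, folds[k]))
    return splits


if __name__ == "__main__":
    subjects = [f"sub{i:04d}" for i in range(20)]
    for train_ids, val_ids in subject_level_folds(subjects, CONFIG["n_folds"]):
        print(len(train_ids), len(val_ids))
```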

Data: Training is performed on 8326 cases of T1-weighted 3D MPRAGE brain MRIs (resolution 1×1×1 mm³) from the UKB and NAKO, acquired at 3T (Siemens Skyra).

Experiments: Investigations are carried out for two tasks, predicting the sex (task 1) and the age (task 2) of subjects. In task 1, the spuriously-correlated factor is the acquisition site (UKB or NAKO). In task 2, the spuriously-correlated factor is sex. For both tasks, a spuriously-correlated training dataset with imbalanced classes is created. Task 1: 90% of female subjects are from UKB and 10% from NAKO (reversed for male subjects). Task 2: subjects are split into two age groups (young: ≤51 years, old: >57 years). The imbalanced dataset distribution of task 2 is visualized in Fig.3.
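The confounded datasets can be constructed by sampling each class with fixed quotas. The sketch below is illustrative only: the `Record` type, the label encodings and sampling without replacement are hypothetical. `p_major=0.9` reproduces the 90%/10% split, `flip=True` yields the flipped test distribution (Fig.3B), and `p_major=0.5` is one possible reading of the balanced dataset. Test sets would be drawn from subjects held out from training.

```python
import random
from typing import List, NamedTuple


class Record(NamedTuple):
    subject_id: str
    label: int  # primary-task class, e.g. task 1: 0 = female, 1 = male
    sc: int     # spuriously-correlated factor, e.g. task 1: 0 = UKB, 1 = NAKO


def confounded_subset(records: List[Record], n_per_class: int, p_major: float,
                      rng: random.Random, flip: bool = False) -> List[Record]:
    """Sample n_per_class records per class with a controlled PT/SC correlation.

    Class 0 draws a fraction p_major of its records from sc = 0, class 1 from sc = 1.
    flip=True swaps this assignment (flipped test environment, Fig.3B).
    """
    selected: List[Record] = []
    n_major = round(p_major * n_per_class)
    for label in (0, 1):
        major_sc = 1 - label if flip else label
        quotas = {major_sc: n_major, 1 - major_sc: n_per_class - n_major}
        for sc, quota in quotas.items():
            pool = [r for r in records if r.label == label and r.sc == sc]
            if len(pool) < quota:
                raise ValueError(f"only {len(pool)} records for label={label}, sc={sc}")
            selected.extend(rng.sample(pool, quota))  # without replacement
    return selected


if __name__ == "__main__":
    data = [Record(f"s{l}{s}{i}", l, s) for l in (0, 1) for s in (0, 1) for i in range(100)]
    train = confounded_subset(data, 50, 0.9, random.Random(0))
    print(sum(r.label == 0 and r.sc == 0 for r in train), "of 50 class-0 samples have sc = 0")
```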

Evaluation: To evaluate the trained models and to demonstrate that shortcut-learning is suppressed, the classification accuracy is computed for a) new unseen samples of a validation dataset with the same distribution as the training set, b) a test set with a flipped distribution (Fig.3B) and c) a dataset with balanced classes. To show that XPT and XSC are independent and do not share information, XPT is used to predict SC, while XSC is used to predict PT. Here, an accuracy close to chance level (50%) is desired.
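The independence check can be illustrated with a simple probe that predicts SC from XPT (or PT from XSC). The abstract does not state which classifier is used for this step; the nearest-centroid probe below is a hypothetical stand-in. Chance level is 50% only for balanced binary labels, so the probe's evaluation set is assumed to be balanced.

```python
from typing import Dict, Hashable, List, Sequence

Vector = Sequence[float]


def accuracy(pred: Sequence[Hashable], target: Sequence[Hashable]) -> float:
    return sum(p == t for p, t in zip(pred, target)) / len(target)


def fit_nearest_centroid(features: Sequence[Vector], labels: Sequence[int]) -> Dict[int, List[float]]:
    """Per-class mean feature vector; a deliberately simple probe."""
    centroids: Dict[int, List[float]] = {}
    for c in sorted(set(labels)):
        members = [f for f, l in zip(features, labels) if l == c]
        centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return centroids


def predict_nearest_centroid(centroids: Dict[int, List[float]],
                             features: Sequence[Vector]) -> List[int]:
    def sq_dist(a: Vector, b: Vector) -> float:
        return sum((x - y) ** 2 for x, y in zip(a, b))
    # ties resolve to the smallest class label (first key in insertion order)
    return [min(centroids, key=lambda c: sq_dist(f, centroids[c])) for f in features]


def leakage_probe(train_x: Sequence[Vector], train_y: Sequence[int],
                  eval_x: Sequence[Vector], eval_y: Sequence[int]) -> float:
    """Accuracy of predicting a label from one FV part (e.g. SC from XPT).

    Values near chance (0.5 for balanced binary labels) suggest little shared information.
    """
    centroids = fit_nearest_centroid(train_x, train_y)
    return accuracy(predict_nearest_centroid(centroids, eval_x), eval_y)
```

The accuracy drop between the validation set and the flipped test set is then simply the difference of two `accuracy` values; a large drop indicates that the model relied on the spurious-correlation.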

Results and Discussion

As shown in the table in Fig.4, the proposed MIM model performs significantly better than the baseline on both tasks. For task 1, the accuracy of the baseline drops by 15% when the data distribution is flipped, whereas the accuracy of the proposed MIM model drops by <0.1%. The same trend is observed on the balanced dataset. With respect to independence between XPT and XSC, the proposed MIM model achieves 54% accuracy in the cross-prediction test, close to chance level, while the baseline model still reaches around 63%. The t-SNE visualization12 in Fig.5B-C demonstrates the removal of spuriously-correlated information from XPT. While the SC classes still form separate clusters in XPT of the baseline (Fig.5B), they are no longer distinguishable in XPT of the MIM model (Fig.5C), indicating that XPT no longer carries SC information. The results of task 2 show the same trend.

Conclusion

We proposed and evaluated a framework (MIM model) that prevents DL models from learning spurious-correlations in MRI data, which would otherwise lead to wrong predictions in new environments and on out-of-distribution data (e.g. when applied at other hospitals or on other scanners).

Acknowledgements

S. G. and T. K. contributed equally.

This project was conducted with data from the German National Cohort (GNC) (www.nako.de). The GNC is funded by the Federal Ministry of Education and Research (BMBF) (project funding reference no. 01ER1301A/B/C and 01ER1511D), federal states, and the Helmholtz Association, with additional financial support from the participating universities and institutes of the Leibniz Association. We thank all participants who took part in the GNC study and the staff in this research program10.

This work was carried out under UK Biobank Application 40040. We thank all participants who took part in the UKB study and the staff in this research program9.

References

  1. Hinton G. Deep Learning—A Technology With the Potential to Transform Health Care. JAMA. 2018;320. doi: 10.1001/jama.2018.11100.
  2. Marcus G. Deep Learning: A Critical Appraisal. CoRR. 2018;abs/1801.00631.
  3. Veitch V, D’Amour A, Yadlowsky S, Eisenstein J. Counterfactual Invariance to Spurious Correlations: Why and How to Pass Stress Tests. CoRR. 2021;abs/2106.00545.
  4. Beery S, Van Horn G, Perona P. Recognition in Terra Incognita. In: ECCV; 2018.
  5. Rao A, Monteiro JM, Mourao-Miranda J, Alzheimer’s Disease Neuroimaging Initiative, et al. Predictive modelling using neuroimaging data in the presence of confounds. NeuroImage. 2017;150:23–49.
  6. Geirhos R, Jacobsen J-H, Michaelis C, Zemel R, Brendel W, Bethge M, Wichmann FA. Shortcut learning in deep neural networks. Nature Machine Intelligence. 2020;2(11):665–673. doi: 10.1038/s42256-020-00257-z.
  7. Brookhart MA, Stürmer T, Glynn RJ, Rassen J, Schneeweiss S. Confounding Control in Healthcare Database Research. Medical Care. 2010;48(6):S114–S120. doi: 10.1097/mlr.0b013e3181dbebe3.
  8. Belghazi MI, Baratin A, Rajeshwar S, Ozair S, et al. Mutual Information Neural Estimation. In: Dy J, Krause A, editors. Proceedings of the 35th International Conference on Machine Learning. PMLR; 2018. p. 531–540. Available from: https://proceedings.mlr.press/v80/belghazi18a.html.
  9. Sudlow C, Gallacher J, Allen N, et al. UK Biobank: An Open Access Resource for Identifying the Causes of a Wide Range of Complex Diseases of Middle and Old Age. PLOS Medicine. 2015;12(3):1–10. doi: 10.1371/journal.pmed.1001779.
  10. German National Cohort (GNC) Consortium. The German National Cohort: aims, study design and organization. European Journal of Epidemiology. 2014;29(5):371–382. doi: 10.1007/s10654-014-9890-7.
  11. Kingma DP, Ba J. Adam: A Method for Stochastic Optimization. arXiv:1412.6980; 2014.
  12. van der Maaten L, Hinton G. Visualizing Data using t-SNE. Journal of Machine Learning Research. 2008;9(86):2579–2605. Available from: http://jmlr.org/papers/v9/vandermaaten08a.html.
  13. LeNail A. NN-SVG: Publication-Ready Neural Network Architecture Schematics. Journal of Open Source Software. 2019;4(33):747. doi: 10.21105/joss.00747.

Figures

Fig.1: Visualization of the causal direction, the spurious-correlation and possible DL model prediction paths when confounding is present in data. (A) Currently: Confounding causes a spurious-correlation, which bears the risk that a DL model learns a shortcut over a spuriously-correlated factor (e.g. scan conditions) instead of the real causal relationship (e.g. anatomic features) between input and output. (B) Proposed: By interrupting the shortcut, learning a spurious-correlation is avoided. Hence, the DL model is able to learn the real causal relationship.

Fig.2: (A) Baseline model: The feature-encoder encodes the MR input image into an FV, which is split into two parts. The upper part XPT predicts the desired task (PT), while the lower part XSC predicts the spuriously-correlated factor (SC). (B) MIM model (proposed): The MI between XPT and XSC is estimated by a MINE model. The aim is to minimize the MI between the two FV parts (an MI of zero indicates that the two vectors are independent). For this purpose, the estimated MI is added as a penalty term to the loss13.

Fig.3: Data distribution during training and testing for the binary classification of age groups, where group young comprises subjects aged 51 years or younger and group old comprises subjects older than 57 years. (A) Confounded training dataset: younger subjects are 90% female and 10% male; older subjects are 90% male and 10% female. (B) Confounded test dataset: a new environment is created by flipping the class ratios; younger subjects are 10% female and 90% male, older subjects are 10% male and 90% female.

Fig.4: Accuracy comparison of the baseline and the proposed MIM model. For XPT predicting SC and XSC predicting PT, a value close to 50% indicates that XPT does not contain information about SC and that XSC does not contain information about PT, respectively. The proposed model performs significantly better than the baseline, and its smaller accuracy drops from the validation to the test dataset demonstrate the removal of the spurious-correlation (shortcut-learning) (task 1: scanner, task 2: sex).

Fig.5: Results of task 1 (prediction of sex; spuriously-correlated factor: scanner). (A) In-domain samples are predicted correctly by both models, but out-of-domain samples are only predicted correctly by the proposed model. (B)+(C) t-SNE plots of XPT (responsible for the prediction of sex), colored by the labels of the spuriously-correlated factor. (B) Baseline: XPT still contains scanner information, since the scanners are clearly distinguishable. (C) Proposed model: the scanners are no longer separable, showing that the spurious-correlation has been removed from XPT.

Proc. Intl. Soc. Mag. Reson. Med. 31 (2023)
DOI: https://doi.org/10.58530/2023/0154