0382

Explaining Deep fMRI Classifiers with Diffusion-Driven Counterfactual Generation
Hasan Atakan Bedel1,2 and Tolga Çukur1,2,3
1Department of Electrical and Electronics Engineering, Bilkent University, Ankara, Turkey, 2National Magnetic Resonance Research Center (UMRAM), Bilkent University, Ankara, Turkey, 3Neuroscience Program, Bilkent University, Ankara, Turkey

Synopsis

Keywords: AI Diffusion Models, Machine Learning/Artificial Intelligence, fMRI, xAI, diffusion, transformers

Motivation: Deep-learning classifiers for functional MRI (fMRI) offer state-of-the-art performance in detection of cognitive states from BOLD responses, but their black-box nature hinders interpretation of results.

Goal(s): Our goal was to devise a reliable method to infer the important BOLD-response attributes that drive the decisions of deep fMRI classifiers.

Approach: We introduced a novel counterfactual explanation method (DreaMR) based on a new fractional, distilled diffusion prior for efficient generation of high-fidelity counterfactual samples.

Results: DreaMR generated more specific and plausible explanations of deep fMRI classifiers trained for resting-state and task-based fMRI analysis than previous state-of-the-art explanation methods.

Impact: The improvement in sensitivity, plausibility and efficiency in explanation of deep classifiers through DreaMR may facilitate adoption of AI-based analyses in fMRI studies, thereby benefiting assessment of cognitive processes in both normal and neurological disease states.

Introduction

Functional MRI (fMRI) recordings of brain responses enable non-invasive assessment of cognitive states evoked during resting state or task execution [1-9]. In recent years, deep-learning (DL) classifiers have enabled sensitivity leaps in detection of cognitive states from fMRI scans [10-16]. Yet, DL classifiers involve hierarchical nonlinear transformations that obscure the links between brain responses and cognitive states [10]. As this creates a barrier to user trust, there is a pressing need for methods that explain the decisions of deep fMRI classifiers [15,16]. Among current explanation frameworks, counterfactual methods [17], which seek the minimum set of changes in brain responses sufficient to distinguish between separate cognitive states, stand out given their elevated sensitivity over attribution or perturbation methods [18-20]. However, previous counterfactual methods for fMRI use variational or adversarial priors that typically lack sample fidelity, compromising specificity and plausibility of subsequent interpretation [21].

To address this limitation, here we introduce DreaMR, which is, to our knowledge, the first diffusion-driven counterfactual explanation method for fMRI classifiers. Combining a novel, efficient diffusion prior with a spatiotemporal transformer architecture, DreaMR outperforms competing methods in counterfactual explanation.

Methods

DreaMR: DreaMR explains fMRI classifiers by generating counterfactual samples via a task-agnostic prior on fMRI data (Fig.1). It employs a novel fractional multi-phase-distilled diffusion (FMD) prior for high sampling fidelity and efficiency. Given a trained cognitive-state classifier, DreaMR resamples the original fMRI scan to find minimal BOLD response changes that alter the classifier's decision. Differences between the original and counterfactual fMRI scans are analyzed to identify important response features for each cognitive state.

a) FMD prior: Common diffusion priors transform Gaussian noise into data over hundreds of steps, and they suffer from poor sample quality when accelerated [22]. To enhance efficiency and sample quality, we partition the diffusion process into \(F\) fractions, each with a distinct denoising network \(\mathbf{D}_{\theta}^{[f]}(x_t)\), where \(x_t\) is the noise-added fMRI sample at step \(t \in [0, T]\) (\(T\): total steps). The training objective is:\[\min_{\theta}\sum_{f=1}^{F}\mathbb{E}_{t \sim U[t_s(f),t_e(f)]}\left[\|\mathbf{D}_{\theta}^{[f]}(x_t)-x_0 \|_2^2\right]\]where \(x_0\) denotes original fMRI scans, and \(t_s(f)\), \(t_e(f)\) denote the start and end times for fraction \(f\). The original teacher networks \(\mathbf{D}_{\theta_0}^{[f]}(x_t)=\mathbf{D}_{\theta}^{[f]}(x_t)\) are iteratively distilled onto student networks \(\mathbf{D}_{\theta_P}^{[f]}\) over \(P\) phases, doubling the step size in each phase as $$$k=2^p$$$. The objective in phase \(p\) is:\[\min_{\theta_p}\mathbb{E}_{t \sim U(t_s(f):T/(2^pF):t_e(f))}\left[\|\mathbf{D}^{[f]}_{\theta_p}(x_t)-\widetilde{x}_{0} \|_2^2\right]\]where $$$\widetilde{x}_{0}$$$ is a reference sample drawn using $$$\mathbf{D}^{[f]}_{\theta_{p-1}}(x_t)$$$ [23].
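As a minimal sketch of the fractional partitioning described above, the snippet below maps a diffusion step \(t\) to its fraction and evaluates the squared-error objective with a stand-in denoiser. Equal-width fractions, the cosine noise schedule, and the names `fraction_index`, `noise_schedule`, `toy_denoiser` and `fraction_loss` are assumptions made for illustration; they are not taken from the DreaMR implementation.

```python
import numpy as np

T, F = 1024, 4  # total steps and number of fractions (values reported under Analyses)

def fraction_index(t, T=T, F=F):
    """Return the 0-based fraction whose interval contains step t.
    Assumes equal-width fractions; the last fraction also absorbs t == T."""
    return min(int(t * F // T), F - 1)

def noise_schedule(t, T=T):
    """Assumed cosine schedule with alpha_t^2 + sigma_t^2 = 1."""
    phase = 0.5 * np.pi * t / T
    return np.cos(phase), np.sin(phase)

def toy_denoiser(f, x_t):
    """Hypothetical stand-in for the fraction-specific network D^[f];
    a real model would be a trained neural network."""
    return x_t * (1.0 - 0.1 * f)

def fraction_loss(x0, t, rng):
    """Squared error between the fraction-specific x0 estimate and the clean sample."""
    alpha, sigma = noise_schedule(t)
    x_t = alpha * x0 + sigma * rng.standard_normal(x0.shape)
    x0_hat = toy_denoiser(fraction_index(t), x_t)
    return float(np.sum((x0_hat - x0) ** 2))

rng = np.random.default_rng(0)
x0 = rng.standard_normal((8, 16))  # toy "scan": 8 regions x 16 time points
losses = [fraction_loss(x0, t, rng) for t in (10, 300, 600, 1000)]
```

Each sampled step is routed to exactly one network, so every network only has to model a narrow range of noise levels.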

b) Counterfactual generation: Given an original fMRI scan $$$x_0$$$ mapped to cognitive state $$$y_0$$$, DreaMR generates a counterfactual scan $$$\bar{x}_0$$$ to elicit a target state $$$\bar{y}_0\neq y_0$$$. After Gaussian noise ($$$\epsilon$$$) is added to $$$x_0$$$ to form $$$\bar{x}_{\Delta T}=\alpha_{\Delta T}x_0+\sigma_{\Delta T}\epsilon$$$, nested diffusion sampling is performed. An inner loop estimates \(\widetilde{x}_0\) via the diffusion prior to enable computation of classifier guidance \(G\) as the gradient of the log-posterior probability of the target state. An outer loop resamples via the diffusion prior complemented with \(G\) to obtain $$$\bar{x}_0$$$:\begin{align*}&\text{For }t_o \text{ in range}(\Delta T,-1,-k):\\&\quad \widetilde{x}_{t_o}=\bar{x}_{t_{o}},\\&\quad \text{For }t_i \text{ in range}(t_o, -1, -k):\\& \quad\quad \widetilde{x}_{t_i-k}=\alpha_{t_i-k}\widehat{x}_{0}+\sigma_{t_i-k}\left(\frac{\widetilde{x}_{t_i}-\alpha_{t_i}\widehat{x}_{0}}{\sigma_{t_i}}\right);\mbox{ }\widehat{x}_{0}=\mathbf{D}_{\theta_P}^{[f]}(\widetilde{x}_{t_i});\\&\quad G=\nabla_{\widetilde{x}_0}\log p_c(\bar{y}_0|\widetilde{x}_0),\\&\quad \bar{x}_{t_o-k}=\alpha_{t_o-k}\widehat{x}_{0}+\sigma_{t_o-k}\left(\frac{\bar{x}_{t_o}-\alpha_{t_o}\widehat{x}_{0}}{\sigma_{t_o}}\right)+\gamma G;\mbox{ }\widehat{x}_{0}=\mathbf{D}_{\theta_P}^{[f]}(\bar{x}_{t_o});\end{align*}As \(\bar{x}_{0}\) minimally modulates \(x_0\) to alter the classifier decision, it enables inference of important brain response features that affect cognitive states.
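The nested sampling procedure can be illustrated with the toy sketch below. It is not the authors' code and rests on several assumptions: a cosine noise schedule, an analytic denoiser that is exact only for standard-normal data (standing in for the distilled fraction-specific networks), a logistic classifier with hypothetical weights `w`, and assumed values for \(\Delta T\) and \(\gamma\). The loops here also stop at step \(k\), so that the last update lands on \(t=0\) without dividing by \(\sigma_0=0\).

```python
import numpy as np

T, k, delta_T, gamma = 1024, 128, 512, 2.0  # delta_T and gamma are assumed values

def schedule(t):
    """Assumed cosine schedule: alpha_0 = 1, sigma_0 = 0."""
    phase = 0.5 * np.pi * t / T
    return np.cos(phase), np.sin(phase)

def denoise(x_t, t):
    """Toy x0 estimator: posterior mean under a standard-normal data prior.
    In DreaMR this role is played by the distilled, fraction-specific network."""
    alpha, _ = schedule(t)
    return alpha * x_t

def ddim_step(x_t, t):
    """Deterministic update from step t to step t - k (requires t >= k)."""
    alpha_t, sigma_t = schedule(t)
    alpha_n, sigma_n = schedule(t - k)
    x0_hat = denoise(x_t, t)
    return alpha_n * x0_hat + sigma_n * (x_t - alpha_t * x0_hat) / sigma_t

def guidance(x0_est, w):
    """Gradient of log p(target | x) for a toy logistic classifier with weights w."""
    p_target = 1.0 / (1.0 + np.exp(-np.sum(w * x0_est)))
    return (1.0 - p_target) * w

def counterfactual(x0, w, rng):
    """Noise x0 up to step delta_T, then resample with classifier guidance."""
    alpha, sigma = schedule(delta_T)
    x_bar = alpha * x0 + sigma * rng.standard_normal(x0.shape)
    for t_o in range(delta_T, 0, -k):      # outer loop: guided resampling
        x_tilde = x_bar
        for t_i in range(t_o, 0, -k):      # inner loop: unguided estimate of x_0
            x_tilde = ddim_step(x_tilde, t_i)
        x_bar = ddim_step(x_bar, t_o) + gamma * guidance(x_tilde, w)
    return x_bar

rng = np.random.default_rng(0)
x0 = rng.standard_normal((8, 16))          # toy scan: 8 regions x 16 time points
w = 0.1 * rng.standard_normal((8, 16))     # hypothetical classifier weights
x_cf = counterfactual(x0, w, rng)
```

Evaluating the classifier on the inner-loop estimate rather than on the noisy sample means the guidance gradient is computed on an input closer to the data the classifier was trained on.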

Analyses: Models were trained on resting-state (rs) fMRI data from HCP-Rest [24] (1093 scans) and tested on task-fMRI data from HCP-Task [24] (7450 scans) and rs-fMRI data from ID1000 [25] (881 scans). DreaMR utilized a transformer architecture [16], adapted for diffusion modeling by incorporating time encoding. Cross-validated hyperparameters were $$$T=1024$$$, $$$F=4$$$, $$$P=7$$$, $$$k=128$$$, and $$$s=32$$$. Training was performed with the Adam optimizer for \(100\) epochs at a learning rate of \(0.0002\) and a batch size of \(8\).
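To make the reported hyperparameters concrete, the short calculation below tabulates how the step size and step counts would evolve across distillation phases. It assumes equal-width fractions and that each phase doubles the step size starting from 1; the variable names are illustrative.

```python
T, F, P = 1024, 4, 7  # total steps, fractions, distillation phases (from the text)

steps_per_fraction = T // F  # teacher steps inside each fraction
schedule = []
for p in range(P + 1):
    k = 2 ** p  # step size after p distillation phases
    schedule.append((p, k, steps_per_fraction // k, T // k))

for p, k, per_fraction, total in schedule:
    print(f"phase {p}: step size {k:3d}, {per_fraction:3d} steps/fraction, {total:4d} steps total")
```

Under these assumptions the final phase gives a step size of \(2^7=128\), matching the reported \(k\), with two sampling steps in each of the four fractions.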

Results

We compared DreaMR against leading counterfactual methods utilizing autoencoder (LatentShift [28]) and diffusion (DVCE [27], DiME [26], DiffSCM [22]) priors to explain transformer-based cognitive-state classifiers [16]. DreaMR generally achieves the highest specificity (i.e., lowest proximity, measured as distance, and lowest sparsity, measured as the percentage of significantly altered features, between original and counterfactual samples) and the highest plausibility (i.e., lowest FID between original and counterfactual distributions) (Fig.3a). DreaMR yields much faster inference than competing diffusion priors and efficiency comparable to autoencoder priors (Fig.3b).
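The two specificity metrics can be sketched as follows. The exact definitions used in the evaluation are not given here, so the Euclidean distance for proximity and the fixed threshold `tau` for "significantly altered" features are assumptions of this example.

```python
import numpy as np

def proximity(x0, x_cf):
    """Distance between original and counterfactual samples (Euclidean, assumed)."""
    return float(np.linalg.norm(x_cf - x0))

def sparsity(x0, x_cf, tau=0.5):
    """Percentage of features whose absolute change exceeds tau (threshold assumed)."""
    changed = np.abs(x_cf - x0) > tau
    return 100.0 * float(np.mean(changed))

# Toy example: 4 regions x 5 time points, with two features altered.
x0 = np.zeros((4, 5))
x_cf = x0.copy()
x_cf[0, 0] = 3.0
x_cf[1, 2] = -4.0

prox = proximity(x0, x_cf)  # sqrt(3^2 + 4^2)
spar = sparsity(x0, x_cf)   # 2 of 20 features changed
```

Lower values of both metrics indicate a counterfactual that stays closer to the original scan and alters fewer features.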

Fig.4 displays explanation maps of BOLD responses (i.e., the difference between original and counterfactual samples) for a cognitive-task classifier on task-fMRI data. Fig.5 shows explanation maps for functional connectivity (i.e., the difference between correlation-based FC matrices for original and counterfactual samples). DreaMR shows superior alignment with characteristic response and FC patterns, consistent with the neuroscience literature [29]. When important features in explanation maps are used for cognitive-state detection via a linear classifier, DreaMR-based features yield a 10% average improvement in classification accuracy over competing methods.
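A minimal sketch of how such explanation maps might be computed is given below, assuming samples are stored as regions-by-time arrays. The sign convention (counterfactual minus original), the helper names, and the top-5% selection rule borrowed from the Fig.5 caption are illustrative assumptions.

```python
import numpy as np

def explanation_map(x0, x_cf):
    """BOLD-level map: counterfactual minus original (sign convention assumed)."""
    return x_cf - x0

def fc_matrix(x):
    """Correlation-based FC: Pearson correlation between regional time series (rows)."""
    return np.corrcoef(x)

def fc_explanation_map(x0, x_cf):
    """Difference between the FC matrices of counterfactual and original samples."""
    return fc_matrix(x_cf) - fc_matrix(x0)

def top_connections(fc_map, fraction=0.05):
    """Region pairs (i, j), i < j, with the largest absolute FC change."""
    i, j = np.triu_indices_from(fc_map, k=1)
    scores = np.abs(fc_map[i, j])
    n_keep = max(1, int(round(fraction * scores.size)))
    order = np.argsort(scores)[::-1][:n_keep]
    return list(zip(i[order].tolist(), j[order].tolist()))

# Toy data: 6 regions with mutually orthogonal sinusoidal time series.
t = np.arange(64)
x0 = np.stack([np.sin(2 * np.pi * (r + 1) * t / 64) for r in range(6)])
x_cf = x0.copy()
x_cf[0] = x0[1]  # the counterfactual couples region 0 to region 1

bold_map = explanation_map(x0, x_cf)
fc_map = fc_explanation_map(x0, x_cf)
top = top_connections(fc_map)
```

In this toy case only the connection between regions 0 and 1 changes appreciably, so it is the one selected.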

Discussion

DreaMR leverages a powerful diffusion prior to efficiently produce high-quality counterfactuals that help identify associations between brain responses and cognitive states. Demonstrations on rs- and task-fMRI datasets suggest that DreaMR surpasses the previous state of the art in interpretation performance, showing promise for explainable fMRI analysis with deep-learning classifiers.

Acknowledgements

This work was supported in part by a TUBITAK 1001 Grant No. 121N029, and in part by a TUBITAK BIDEB scholarship.

References

[1] J. W. Belliveau, D. N. Kennedy, R. C. McKinstry, B. R. Buchbinder, R. M. Weisskoff, M. S. Cohen, J. M. Vevea, T. J. Brady, and B. R. Rosen, “Functional Mapping of the Human Visual Cortex by Magnetic Resonance Imaging,” Science, vol. 254, no. 5032, pp. 716–719, 1991.

[2] S. Ogawa, D. W. Tank, R. Menon, J. M. Ellermann, S. G. Kim, H. Merkle, and K. Ugurbil, “Intrinsic signal changes accompanying sensory stimulation: Functional brain mapping with magnetic resonance imaging,” Proc. Natl. Acad. Sci. U. S. A., vol. 89, no. 13, pp. 5951–5955, 1992.

[3] S. D. Forman, J. D. Cohen, M. Fitzgerald, W. F. Eddy, M. A. Mintun, and D. C. Noll, “Improved Assessment of Significant Activation in Functional Magnetic Resonance Imaging (fMRI): Use of a Cluster‐Size Threshold,” Magn. Reson. Med., vol. 33, no. 5, pp. 636–647, 1995.

[4] K. L. Miller, B. A. Hargreaves, J. Lee, D. Ress, R. C. DeCharms, and J. M. Pauly, “Functional brain imaging using a blood oxygenation sensitive steady state,” Magn. Reson. Med., vol. 50, no. 4, pp. 675–683, Oct. 2003.

[5] S. Makni, J. Idier, T. Vincent, B. Thirion, G. Dehaene-Lambertz, and P. Ciuciu, “A fully Bayesian approach to the parcel-based detection-estimation of brain activity in fMRI,” Neuroimage, vol. 41, no. 3, pp. 941–969, Jul. 2008.

[6] M. Bianciardi, M. Fukunaga, P. van Gelderen, S. G. Horovitz, J. A. de Zwart, K. Shmueli, and J. H. Duyn, “Sources of functional magnetic resonance imaging signal fluctuations in the human brain at rest: a 7 T study,” Magn. Reson. Imaging, vol. 27, no. 8, pp. 1019–1029, Oct. 2009.

[7] C. Y. Wee, P. T. Yap, D. Zhang, K. Denny, J. N. Browndyke, G. G. Potter, K. A. Welsh-Bohmer, L. Wang, and D. Shen, “Identification of MCI individuals using structural and functional connectivity networks,” Neuroimage, vol. 59, no. 3, pp. 2045–2056, Feb. 2012.

[8] W. T. Chang, A. Nummenmaa, T. Witzel, J. Ahveninen, S. Huang, K. W. K. Tsai, Y. H. Chu, J. R. Polimeni, J. W. Belliveau, and F. H. Lin, “Whole-head rapid fMRI acquisition using echo-shifted magnetic resonance inverse imaging,” Neuroimage, vol. 78, pp. 325–338, Sep. 2013.

[9] F. Wang, Z. Dong, L. L. Wald, J. R. Polimeni, and K. Setsompop, “Simultaneous pure T2 and varying T2′-weighted BOLD fMRI using Echo Planar Time-resolved Imaging for mapping cortical-depth dependent responses,” Neuroimage, vol. 245, Dec. 2021.

[10] M. Chiew, S. M. Smith, P. J. Koopmans, N. N. Graedel, T. Blumensath, and K. L. Miller, “k-t FASTER: Acceleration of functional MRI data acquisition using low rank constraints,” Magn. Reson. Med., vol. 74, no. 2, pp. 353–364, Aug. 2015.

[11] H. Huang, X. Hu, Y. Zhao, M. Makkie, Q. Dong, S. Zhao, Y. Zhao, J. Han, L. Guo, and T. Liu, “Modeling task fMRI data via deep convolutional autoencoder,” IEEE Trans Med Imaging, vol. 37, no. 7, pp. 1551–1561, 2017.

[12] S. Parisot, S. I. Ktena, E. Ferrante, M. Lee, R. Guerrero, B. Glocker, and D. Rueckert, “Disease prediction using graph convolutional networks: application to autism spectrum disorder and Alzheimer’s disease,” Med Image Anal, vol. 48, pp. 117–130, 2018.

[13] L. Wang, K. Li, and X. P. Hu, “Graph convolutional network for fMRI analysis based on connectivity neighborhood,” Net Neurosci, vol. 5, no. 1, pp. 83–95, 2021.

[14] X. Li, Y. Zhou, N. Dvornek, M. Zhang, S. Gao, J. Zhuang, D. Scheinost, L. H. Staib, P. Ventola, and J. S. Duncan, “BrainGNN: Interpretable Brain Graph Neural Network for fMRI Analysis,” Med. Image Anal., vol. 74, p. 102233, Dec. 2021.

[15] J. Zhang, L. Zhou, L. Wang, M. Liu, and D. Shen, “Diffusion Kernel Attention Network for Brain Disorder Classification,” IEEE Trans. Med. Imaging, vol. 41, no. 10, pp. 2814–2827, Oct. 2022.

[16] H. A. Bedel, I. Sivgin, O. Dalmaz, S. U. H. Dar, and T. Çukur, “BolT: Fused window transformers for fMRI time series analysis,” Med. Image Anal., vol. 88, p. 102841, Aug. 2023.

[17] T. Matsui, M. Taki, T. Q. Pham, J. Chikazoe, and K. Jimura, “Counterfactual explanation of brain activity classifiers using image-to-image transfer by generative adversarial network,” Front Neuroinf, vol. 15, p. 79, 2022.

[18] B.-H. Kim and J. C. Ye, “Understanding graph isomorphism network for rs-fMRI functional connectivity analysis,” Front Neurosci, p. 630, 2020.

[19] T. Zhao, P. Tubiolo, T. Hagan, J. Williams, J. Snellenberg, and C. Huang, “Using interpretable deep learning on task fMRI data to understand brain regions related to working memory - a repeatability study,” in Annual Meeting of the ISMRM, 2023, p. 2712.

[20] B. J. Devereux, A. Clarke, and L. K. Tyler, “Integrated deep visual and semantic attractor neural networks predict fMRI pattern-information along the ventral object processing pathway,” Sci Rep, vol. 8, no. 1, p. 10636, 2018.

[21] B. H. van der Velden, H. J. Kuijf, K. G. Gilhuijs, and M. A. Viergever, “Explainable artificial intelligence (XAI) in deep learning-based medical image analysis,” Med Image Anal, vol. 79, p. 102470, 2022.

[22] P. Sanchez, A. Kascenas, X. Liu, A. Q. O’Neil, and S. A. Tsaftaris, “What is healthy? generative counterfactual diffusion for lesion localization,” in DGM4MICCAI, 2022, pp. 34–44.

[23] T. Salimans and J. Ho, “Progressive distillation for fast sampling of diffusion models,” arXiv:2202.00512, 2022.

[24] D. C. Van Essen, S. M. Smith, D. M. Barch, T. E. Behrens, E. Yacoub, K. Ugurbil et al., “The WU-Minn human connectome project: an overview,” NeuroImage, vol. 80, pp. 62–79, 2013.

[25] L. Snoek, M. M. van der Miesen, T. Beemsterboer, A. van der Leij, A. Eigenhuis, and H. Steven Scholte, “The Amsterdam Open MRI Collection, a set of multimodal MRI datasets for individual difference analyses,” Sci Data, vol. 8, no. 1, pp. 1–23, 2021.

[26] G. Jeanneret, L. Simon, and F. Jurie, “Diffusion models for counterfactual explanations,” in ACCV, 2022, pp. 858–876.

[27] M. Augustin, V. Boreiko, F. Croce, and M. Hein, “Diffusion visual counterfactual explanations,” Advances in Neural Information Processing Systems, vol. 35, pp. 364–377, 2022.

[28] J. P. Cohen, R. Brooks, S. En, E. Zucker, A. Pareek, M. P. Lungren, and A. Chaudhari, “Gifsplanation via Latent Shift: A Simple Autoencoder Approach to Counterfactual Generation for Chest X-rays,” arXiv:2102.09475, 2021.

[29] S. J. Ritchie, S. R. Cox, X. Shen, M. V. Lombardo, L. M. Reus, C. Alloza et al., “Sex differences in the adult human brain: evidence from 5216 UK biobank participants,” Cereb Cortex, vol. 28, no. 8, pp. 2959–2975, 2018.

Figures

Figure 1: (a) DreaMR is a novel counterfactual method to explain a deep classifier that predicts cognitive state given a subject’s fMRI scan (i.e., BOLD responses across time and brain regions). (b) DreaMR trains a class-agnostic fMRI prior, distilled onto a specialized denoising network for each time fraction \(\mathbf{D}_{\theta_P}^{[f]}\), for efficient and high-fidelity sampling. (c) For input fMRI sample \(x_0\) mapped to cognitive state \(y_0\), DreaMR resamples \(x_0\) via the diffusion prior and classifier guidance to change the decision to \(\bar{y}_0 \neq y_0\).

Figure 2: DreaMR’s algorithm for counterfactual generation. Starting with a noise-added version of the original fMRI sample (i.e., a subject’s fMRI scan), reverse diffusion is performed across consecutive time fractions using fraction-specific denoising networks. During generation, classifier guidance is injected at each step to refine the sample so as to elicit the target counterfactual label for cognitive state from the classifier. Guidance is computed as the gradient of the log-posterior-probability evaluated for intermediate estimates of the denoised fMRI sample.


Figure 3: (a) Evaluation of explanation performance is listed for both time-series fMRI data and functional connectivity features derived from the time series. Results are shown for HCP-Task and ID1000 datasets. Performance is listed as mean\(\pm\)std for Proximity (Prox.), Sparsity (Spar.), and FID metrics across test sets. Lower metrics indicate higher performance. The top-performing method is marked in bold. (b) Inference times (Inf., msec) per counterfactual fMRI sample generation for competing methods.

Figure 4: Explanation methods produce maps denoting the importance of fMRI features in distinguishing between cognitive states. Maps produced by competing methods are shown for a representative fMRI scan from HCP-Task recorded during a motor task. The original fMRI sample (left) and global average of fMRI samples across subjects (right) are also shown for the motor task. Separate counterfactual samples were generated to flip the class label from the motor task onto each of the six remaining tasks, and explanation maps were taken as the average difference between original and counterfactual samples.

Figure 5: DreaMR was used to explain a gender classifier that detects subject gender from rs-fMRI scans. Importance scores of functional connections in the brain were evaluated by producing explanation maps for each class. Functional connections with the top 5% of importance scores are visualized for female (top) and male (bottom) classes. Dots denote brain regions, bars denote connections, and their sizes indicate importance scores. LH/RH: left/right hemisphere; SomMot: somatomotor; DorsAttn: dorsal attention; SalventAttn: Salience/ventral attention.

Proc. Intl. Soc. Mag. Reson. Med. 32 (2024) 0382
DOI: https://doi.org/10.58530/2024/0382