0690

Model Pruning for Generalisability Across MRI Imaging Sites
Nicola Katharine Dinsdale1, Mark Jenkinson2,3, and Ana IL Namburete4
1WIN / Oxford Machine Learning in NeuroImaging Lab, University of Oxford, Oxford, United Kingdom, 2WIN, University of Oxford, Oxford, United Kingdom, 3Australian Institute for Machine Learning, University of Adelaide, Adelaide, Australia, 4Computer Science, Oxford Machine Learning in NeuroImaging Lab, University of Oxford, Oxford, United Kingdom

Synopsis

We propose an algorithm that simultaneously prunes and trains CNNs, yielding networks that generalise better across imaging sites. Using segmentation of data from the ABIDE dataset, we show that reducing the number of parameters in the network throughout training reduces model overfitting, creating a model that is more robust to the image variations expected across scanners. We also introduce a novel Targeted Dropout algorithm, which aids the process of model pruning. We demonstrate the approach on a UNet architecture, the basis of nearly all segmentation approaches across medical imaging.

Introduction

In this work, we investigate model pruning to reduce deep learning model overfitting, by targeted removal of parameters during the training process. This is important since semantic segmentation in medical imaging is vital for understanding and monitoring disease progression, but manual segmentation is time consuming and difficult (e.g. due to ambiguous structure boundaries, reducing consistency of different raters). Therefore, reliable and accurate automatic segmentation methods are needed and deep learning-based methods are the state-of-the-art for medical image segmentation. Most methods are based on the UNet architecture1, but without very large numbers of manual segmentations for training, these models overfit to the training data, leading to poor performance, especially when applied to data from different distributions (e.g. data from different MRI scanners or acquired with different protocols). A primary cause of this overfitting is the vast number of parameters in such convolutional neural networks (CNNs), and our pruning approach aims to tackle this problem.

Methods

We propose the TrainPrune + Targeted Dropout (TD) algorithm (Fig. 1), a novel approach aiming to prune and train a model simultaneously. Consider the scenario where there is access to a training dataset $$$\mathcal{D} =\{\mathbf{X}=\{ \mathbf{x}_1,\mathbf{x}_2,\ldots,\mathbf{x}_N\},\mathbf{Y}=\{ \mathbf{y}_1,\mathbf{y}_2,\ldots,\mathbf{y}_N\}\}$$$ where $$$\mathbf{x}_n\in \mathbb{R}^{W\times H\times D\times1}$$$ represents an input image and $$$\mathbf{y}_n\in \mathbb{R}^{W\times H\times D\times C}$$$ is the corresponding target segmentation with $$$C$$$ classes. A network is trained to predict the target segmentations from the input images, and is parameterized by $$$\mathcal{W}=\{(\mathbf{w}_1^1,b_1^1),(\mathbf{w}_1^2,b_1^2),\ldots,(\mathbf{w}_L^{C_L},b_L^{C_L})\}$$$ where $$$\mathbf{w}_m^n$$$ and $$$b_m^n$$$ correspond to the weights and biases for the $$$n$$$th filter of the $$$m$$$th layer, and $$$C_L$$$ is the number of channels in layer $$$L$$$. The weights and biases (represented jointly by $$$\mathcal{W}$$$) are first randomly initialized and then optimized to minimize the loss $$$\mathcal{L}(\mathcal{D}|\mathcal{W})$$$. The aim of the pruning process is to refine the parameters, $$$\mathcal{W}$$$, of the network to a smaller subset, $$$\mathcal{W}'$$$, such that $$$\mathcal{L}(\mathcal{D}|\mathcal{W})\approx \mathcal{L}(\mathcal{D}|\mathcal{W}')$$$.
We prune whole convolutional filters2 and, to decide which filters to remove, we consider the feature activation maps, $$$\mathbf{z}_l^{(k)}=\mathrm{ReLU}(\mathbf{z}_{l-1}*\mathbf{w}_l^{(k)}+b_l^{(k)})$$$, where ReLU is the activation function used throughout the network. The weights with the smallest impact on the final prediction are those with the smallest average magnitudes3; thus the weights and biases corresponding to the smallest magnitude filter activations $$$\mathbf{z}_{l}^{(k)}$$$, averaged over the training data, are removed. To assess the magnitude of the filters, the L2-norm is used as it is computationally simple and provides stable performance: $$$\Theta_{L_2}(\mathbf{z}_l^{(k)})=\frac{1}{N}\sum^N_{i=1}{||\;\mathbf{z}_{l,i}^{(k)}\;||}_2$$$. These values are then normalised across the layer depth (summing over the filters $$$j$$$ of layer $$$l$$$) using L2 normalisation, $$$\hat{\Theta}_{L_2}(\mathbf{z}_l^{(k)})=\frac{\Theta_{L_2}(\mathbf{z}_l^{(k)})}{\sqrt{\sum_j (\Theta_{L_2}(\mathbf{z}_l^{(j)}))^2 }}$$$, to account for values being different at different scales and depths2.
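The scoring step can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the array layout (N, K, W, H, D), the function name `filter_scores`, and the random toy activations are assumptions made for illustration, and the normalisation is assumed to run over the K filters within a single layer.

```python
import numpy as np

def filter_scores(activations):
    """Return layer-normalised L2 scores for the K filters of one layer.

    activations: post-ReLU feature maps of shape (N, K, W, H, D) for
    N training images (hypothetical layout, chosen for this sketch).
    """
    n, k = activations.shape[:2]
    # L2 norm of each filter's activation map, per image: shape (N, K)
    per_image = np.linalg.norm(activations.reshape(n, k, -1), axis=2)
    # Average over the N training images: Theta_{L2}, shape (K,)
    theta = per_image.mean(axis=0)
    # L2-normalise across the filters of this layer (assumed axis)
    denom = np.sqrt(np.sum(theta ** 2))
    return theta / denom if denom > 0 else theta

# Toy activations: 5 images, 4 filters, 3x3x3 feature maps
rng = np.random.default_rng(0)
acts = np.maximum(rng.normal(size=(5, 4, 3, 3, 3)), 0.0)
acts[:, 2] *= 0.01  # make filter 2 nearly inactive
scores = filter_scores(acts)
prune_index = int(np.argmin(scores))  # candidate filter to remove
```

In this toy case the nearly inactive filter receives the smallest normalised score and is therefore the pruning candidate.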
To make the model robust via pruning, we introduce adaptive channelwise targeted dropout, where the dropout probability is based on the average index of the filters, ordered by magnitude at a given depth (Algorithm 1, Fig. 1). Thus, the filters with the lowest magnitude have the highest dropout probability, and so are more likely to be removed during training. As we are pruning channelwise, spatial dropout4 is applied.
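The idea of rank-based channelwise dropout can be sketched as follows. Algorithm 1 is not reproduced in the text, so the mapping from rank to probability used here (linear in rank, from 0 for the largest-magnitude filter up to `b_drop` for the smallest) is an assumption, as are the function names and the inverted-dropout rescaling. The sketch also uses a single ranking rather than an index averaged over training.

```python
import numpy as np

def targeted_dropout_probs(scores, b_drop=0.05):
    """Map filter scores to per-channel dropout probabilities.

    Assumed rule (not the authors' Algorithm 1): the filter with the
    largest score gets probability 0, the smallest gets b_drop, and
    the rest are spaced linearly by rank.
    """
    scores = np.asarray(scores, dtype=float)
    k = scores.size
    order = np.argsort(-scores)      # filter indices, largest score first
    ranks = np.empty(k, dtype=int)
    ranks[order] = np.arange(k)      # rank 0 = largest magnitude
    return b_drop * ranks / max(k - 1, 1)

def spatial_dropout(x, probs, rng):
    """Zero whole channels of x (shape (K, W, H, D)) with the given probabilities."""
    keep = rng.random(probs.shape) >= probs
    # Inverted-dropout rescaling so the expected activation is unchanged
    scale = keep / (1.0 - probs)
    return x * scale[:, None, None, None]

probs = targeted_dropout_probs([0.7, 0.1, 0.5, 0.3])
dropped = spatial_dropout(np.ones((4, 2, 2, 2)), probs, np.random.default_rng(0))
```

Under this assumed rule the lowest-magnitude filter is the one most often zeroed during training, which matches the behaviour described above; each channel of `dropped` is either entirely zero or uniformly rescaled.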
To explore the effect of pruning on model generalisability, we consider T1 data from three sites in the ABIDE dataset5. As manual segmentation labels are not available, we created proxy labels using FSL FAST6, creating tissue segmentations (C=4). Data from the UM site were used for training; performance was tested on a held-out test set (UM) and on two unseen sites: Yale (similar age profile) and MPG (significantly different age profile). A standard 3D UNet architecture7 was considered (Fig. 2), as it is the architecture most commonly used in medical image segmentation. We compare our proposed TrainPruneTD approach to a Standard UNet approach. Across the experiments, two epochs of training were completed between successive filter pruning steps, and $$$b_{drop}$$$ (the dropout probability scaling factor) was set to 0.05 in all cases. The number of filters in the first layer, $$$f$$$, was set to 4 unless otherwise specified, and all layers except the final layer ($$$l=18$$$) were subject to pruning.
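The alternation between training and pruning, with two epochs between pruning steps and model selection on validation data as described in the Fig. 1 caption, could be organised roughly as below. This is a schematic sketch only: the three callables are hypothetical hooks, removing one filter per pruning step is an assumption, and the toy validation curve is made up purely to exercise the loop.

```python
from typing import Callable, List, Tuple

def train_prune(
    train_epochs: Callable[[int], None],
    prune_one_filter: Callable[[], bool],
    validate: Callable[[], float],
    epochs_per_prune: int = 2,
) -> Tuple[int, List[float]]:
    """Alternate training and pruning; return the best step and all validation scores.

    All three callables are hypothetical hooks supplied by the caller:
    train_epochs(n) trains for n epochs, prune_one_filter() removes the
    lowest-scoring filter and returns False when nothing more can be
    pruned, and validate() returns a score where higher is better.
    """
    history: List[float] = []
    while True:
        train_epochs(epochs_per_prune)
        history.append(validate())
        if not prune_one_filter():
            break
    best_step = max(range(len(history)), key=history.__getitem__)
    return best_step, history

# Toy stand-ins: a "model" with 3 prunable filters and a made-up validation curve
state = {"filters": 3, "epochs": 0}
curve = iter([0.80, 0.85, 0.83, 0.60])

def _train(n: int) -> None:
    state["epochs"] += n

def _prune() -> bool:
    if state["filters"] == 0:
        return False
    state["filters"] -= 1
    return True

best_step, history = train_prune(_train, _prune, lambda: next(curve))
```

Keeping the full validation history, rather than stopping at the first drop in score, mirrors the described procedure of pruning until no further filters can be removed and then choosing the best model afterwards.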

Results and Discussion

Figure 3 compares TrainPruneTD with the Standard UNet for different initial model architecture sizes. Models can be pruned heavily without decreasing performance on the task. Critically, pruning increases performance on the two unseen sites compared to the Standard UNet, improving the generalisability of the model.
Figure 4 shows the results for $$$f=4$$$, comparing TrainPruneTD to the Standard UNet with and without augmentation8 during training, to simulate changes expected between sites. The TrainPruneTD model still outperforms the Standard UNet on the two unseen sites, indicating that although augmentation can increase generalisability for these models, our pruning approach does not explicitly require the identification of all the likely changes between sites. The same result was also seen when comparing TrainPruneTD to standard dropout9, with TrainPruneTD generalising better across the sites. The performance of the TrainPruneTD and Standard UNet approaches for increasing levels of applied bias field8 (Fig. 5) further demonstrates improved generalisability, as TrainPruneTD is more robust to such intensity-based changes.
These results clearly demonstrate that pruning during training can increase the generalisability of CNN based models and, thus, their applicability to multisite MRI data, and the flexible algorithm should be applicable across tasks and architectures.

Acknowledgements

ND is supported by the Engineering and Physical Sciences Research Council (EPSRC) and Medical Research Council (MRC) [grant number EP/L016052/1]. MJ is supported by the National Institute for Health Research (NIHR), Oxford Biomedical Research Centre (BRC), and this research was funded by the Wellcome Trust [215573/Z/19/Z]. The Wellcome Centre for Integrative Neuroimaging is supported by core funding from the Wellcome Trust [203139/Z/16/Z]. AN is grateful for support from the UK Royal Academy of Engineering under the Engineering for Development Research Fellowships scheme. The computational aspects of this research were supported by the Wellcome Trust Core Award [Grant Number 203141/Z/16/Z] and the NIHR Oxford BRC. The views expressed are those of the author(s) and not necessarily those of the NHS, the NIHR or the Department of Health.

References

1. Ronneberger O, Fischer P, and Brox T. U-Net: Convolutional networks for biomedical image segmentation. In Medical Image Computing and Computer-Assisted Intervention (MICCAI), volume 9351 of LNCS, pages 234–241. Springer, 2015.

2. Molchanov P, Tyree S, Karras T, et al. Pruning convolutional neural networks for resource efficient transfer learning. In International Conference on Learning Representations, 11 2016.

3. Le Cun Y, Denker J, and Solla S. Optimal brain damage. In Advances in Neural Information Processing Systems 2, pages 598–605, San Francisco, CA, USA, 1990. Morgan Kaufmann Publishers Inc. ISBN 1558601007.

4. Tompson J, Goroshin R, Jain A et al. Efficient object localization using convolutional networks. In 2015 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 648–656, 2015. doi: 10.1109/CVPR.2015.7298664.

5. Di Martino A, Yan Chao-Gan, Li Qingyang et al. The Autism Brain Imaging Data Exchange: Towards large-scale evaluation of the intrinsic brain architecture in Autism. Molecular psychiatry, 19, 06 2013. doi: 10.1038/mp.2013.78.

6. Zhang YY, Brady M, and Smith S. Segmentation of brain MR images through a hidden Markov random field model and the expectation maximization algorithm. IEEE Transactions on Medical Imaging, 20:45–57, 02 2001. doi: 10.1109/42.906424.

7. Pérez-García F, Sparks R, and Ourselin S. TorchIO: a Python library for efficient loading, preprocessing, augmentation and patch-based sampling of medical images in deep learning. arXiv:2003.04696 March 2020.

8. Srivastava N, Hinton G, Krizhevsky A et al. Dropout: A simple way to prevent neural networks from overfitting. Journal of Machine Learning Research, 15(56):

Figures

The proposed TrainPrune + Targeted Dropout training regime is shown by the flow chart, with the targeted dropout being shown in Algorithm 1. Across the experiments, two training epochs were used between pruning steps, and bdrop was empirically set to 0.05. The training procedure was iterated until no filters remained that could be pruned without breaking the model. The best performing model was then chosen using the performance on the validation data.

A 3D UNet was used in all experiments as the majority of segmentation methods use this or a derivative of it. Our network has 18 layers of convolutional filters, in which whole filters are pruned. The final layer cannot be pruned as each channel is an output class. The input data are T1-weighted brain MRIs, and the target segmentation labels are white matter, grey matter, CSF and background (C=4), created using FSL FAST. The starting number of filters, f, is set to 4 unless stated otherwise.

The TrainPruneTD model was trained with f=(2,4,8) and compared to the best Standard UNet (f=4). Results (UM training data) show Dice values (UM test data) as parameters are removed. TrainPruneTD performance holds even with far fewer parameters. For f=2, performance was poor but pruned larger models, with the same number of parameters, did well, showing the benefits of pruning. Results from unseen sites (MPG + Yale) show increases in performance/generalisability when pruned heavily.

We compared TrainPruneTD to the Standard UNet (both f=4) with on-the-fly augmentations designed to simulate scanner differences. With augmentation, Standard UNet performance on UM testing data decreased, but increased for the unseen sites. TrainPruneTD also performed better on unseen sites but maintained performance on the training site (UM). This shows that pruning reduces overfitting better across all sites. Dropout results showed the same pattern, with only moderate levels reducing performance on UM data.

Increasing levels of bias field were applied to the testing images (UM) showing substantial performance decreases for the Standard UNet, but near constant performance for TrainPruneTD. Lower panels show example segmentations, with performance in box plots shown by correspondingly coloured points. Fine segmentation details were lost as performance decreased, showing the extent of overfitting in standard training approaches, and the increase in generalisability that TrainPruneTD brings.

Proc. Intl. Soc. Mag. Reson. Med. 30 (2022)
DOI: https://doi.org/10.58530/2022/0690