Zhiyang Fu1, Maria I Altbach2, Diego R Martin2, and Ali Bilgin1,2,3
1Electrical and Computer Engineering, University of Arizona, Tucson, AZ, United States, 2Department of Medical Imaging, University of Arizona, Tucson, AZ, United States, 3Biomedical Engineering, University of Arizona, Tucson, AZ, United States
Synopsis
MR images are often reconstructed first and then used for medical image analysis tasks such as segmentation or classification. This sequential procedure can compromise the performance of the image analysis task. In this work, we propose a multi-task learning framework that jointly reconstructs underlying images and detects multiple sclerosis lesions. This framework outperforms the
conventional sequential processing pipeline. We also introduce multi-objective optimization as an effective, automated approach to balance
the trade-off between the task losses. Experimental results suggest that accounting for the subsequent detection task during
image reconstruction may lead to enhanced detection performance.
Introduction
Image reconstruction and detection tasks in MRI are conventionally
considered sequentially: images are first reconstructed without
consideration of the subsequent detection task, and the detection task is
carried out on the reconstructed images. With the recent success of deep
learning methods in both image reconstruction and detection problems, we
investigate whether a joint image estimation (i.e. reconstruction) and
detection framework can offer improved estimation and/or detection performance.
We propose a deep learning framework for multi-task learning (MTL)1–4 where the learning problem is
formulated as multi-objective optimization (MOO).4,5 We show that multi-objective learning yields solutions
superior to those obtained from per-task learning or conventional sequential
processing approaches.

Methods
Fig.1a illustrates the conventional processing pipeline where
reconstructed images are subsequently used in detection tasks. Fig.1b
and 1c show the alternative reconstruction-only and detection-only
pipelines. The proposed MTL approach is shown in Fig.1d. MTL networks
generally consist of a shared module1,6 parameterized by $$$\theta^{sh}$$$ followed by task-specific modules parameterized by
$$$\theta^{t}$$$. Multi-objective learning combines the losses of estimation ($$$\mathcal{L}_e$$$) and
detection ($$$\mathcal{L}_d$$$) with a task weighting parameter $$$\lambda$$$, that is
$$\min_{\theta^{sh},\theta^{e},\theta^{d},\lambda}\frac{1}{N}\sum_i \lambda\mathcal{L}_e(f^e(x_i, \theta^{sh},\theta^e),y^e_i)+ (1-\lambda)\mathcal{L}_d(f^d(x_i, \theta^{sh},\theta^d),y^d_i), \; s.t. \, 0\leq\lambda\leq1,$$
where $$$f^e(\cdot)$$$ and $$$f^d(\cdot)$$$ denote the network outputs, and $$$(x_i, y_i^e)$$$ and
$$$(x_i, y_i^d)$$$ represent the supervised training pairs for the
estimation and detection tasks, respectively. In our MOO, the network parameters and
task weighting are updated using alternating optimization. At each
iteration, the task weighting $$$\lambda$$$ is estimated with an analytical
solution4 and then held
fixed for a standard update of the network parameters. As an alternative to MOO, we also trained networks using several fixed
values of $$$\lambda$$$ for
comparison.
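For two tasks, the analytical update of $$$\lambda$$$ in reference 4 is the closed-form MGDA solution: the convex combination of the two task gradients with minimum norm. A minimal sketch of one alternating iteration is shown below; the gradient vectors are hypothetical toy values, whereas the actual implementation uses the gradients of the two losses with respect to the shared parameters:

```python
import numpy as np

def closed_form_lambda(g_e, g_d):
    """Two-task MGDA weight: the lambda in [0, 1] minimizing
    ||lambda * g_e + (1 - lambda) * g_d||^2 (Sener & Koltun, 2018)."""
    diff = g_e - g_d
    lam = np.dot(g_d - g_e, g_d) / (np.dot(diff, diff) + 1e-12)
    return float(np.clip(lam, 0.0, 1.0))

# One alternating iteration: estimate lambda analytically, hold it fixed,
# then take a weighted gradient step on the shared parameters.
theta_sh = np.array([1.0, 1.0])
g_e = np.array([1.0, 0.0])  # hypothetical gradient of the estimation loss
g_d = np.array([0.0, 1.0])  # hypothetical gradient of the detection loss
lam = closed_form_lambda(g_e, g_d)
theta_sh -= 0.1 * (lam * g_e + (1.0 - lam) * g_d)
```

When the gradients conflict (as in the orthogonal toy example above, where $$$\lambda=0.5$$$), both losses contribute; when one gradient already dominates the other, $$$\lambda$$$ is clipped to 0 or 1 and the update reduces to single-task learning.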
Our MTL network
architecture was derived from enhanced residual network (ERN)7
and dilated residual network (DRN).8
Observing that similar residual blocks are used in the early
layers of both networks, we adopted 4 residual blocks as the shared
module, where each residual block consists of
“Conv-ReLU-Conv-SkipConnection”. An additional 12 residual blocks from
ERN were attached to the shared module for the estimation task.
Similarly, the remaining layers of DRN-C-268
are attached for the detection task. We use $$$\ell_1$$$ norm and cross entropy
as the losses for estimation and detection, respectively. Class
weighting was applied based on class frequency to account for the
class imbalance of the training data.
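The shared-trunk design can be sketched as follows. This is a toy 1-D stand-in: `np.convolve` replaces the actual 2-D Conv layers, the kernel shapes and block counts in the heads are illustrative, and the losses are omitted. It is meant only to show how one shared module feeds the two task-specific stacks:

```python
import numpy as np

def conv1d(x, w):
    # Toy 1-D "same" convolution standing in for a Conv layer
    return np.convolve(x, w, mode="same")

def residual_block(x, w1, w2):
    # "Conv-ReLU-Conv-SkipConnection"
    h = np.maximum(conv1d(x, w1), 0.0)
    return x + conv1d(h, w2)

def mtl_forward(x, shared_ws, est_ws, det_ws):
    # Shared module (theta_sh): 4 residual blocks in this work
    for w1, w2 in shared_ws:
        x = residual_block(x, w1, w2)
    # Task-specific modules: extra ERN blocks (theta_e) for estimation and
    # the remaining DRN layers (theta_d) for detection, both toy blocks here
    est, det = x.copy(), x.copy()
    for w1, w2 in est_ws:
        est = residual_block(est, w1, w2)
    for w1, w2 in det_ws:
        det = residual_block(det, w1, w2)
    return est, det
```

With all-zero kernels each block reduces to the identity map, which makes the skip-connection structure of the trunk and heads easy to verify.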
T2
FLAIR images of the brain with multiple sclerosis (MS) lesions were
generated per-subject using the anatomical models of 20 normal
subjects from BrainWeb9
and 2229 individual lesion volume labels from the MICCAI10,11
database. The T2 FLAIR sequence was simulated with TE/TR/TI = 114/8000/1800 ms. Non-uniformity was added to all tissue types
using a scaling factor drawn from a uniform distribution U(0.9, 1.1). Radial
data acquisition using golden angle sampling with acceleration
factors (AFs) of 2, 5, 8, and 11 was used in the experiments and
phase modulation as well as k-space noise (SNR=32dB) were included.
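The golden-angle sampling pattern can be sketched as below. The fully sampled spoke count is an assumption (derived here from a hypothetical 256-point readout via the radial Nyquist criterion), since the abstract does not state the matrix size:

```python
import numpy as np

# Golden angle: 180 degrees divided by the golden ratio, ~111.246 degrees
GOLDEN_ANGLE = 180.0 / ((1.0 + np.sqrt(5.0)) / 2.0)

def golden_angle_spokes(n_readout, af):
    """Spoke angles (degrees) for an accelerated golden-angle radial scan.
    n_readout: readout length (assumed 256 here); af: acceleration factor."""
    n_full = int(np.ceil(np.pi / 2.0 * n_readout))  # Nyquist spoke count
    n_spokes = int(np.ceil(n_full / af))
    return (np.arange(n_spokes) * GOLDEN_ANGLE) % 180.0

# Spoke angles for the four acceleration factors used in the experiments
angles = {af: golden_angle_spokes(256, af) for af in (2, 5, 8, 11)}
```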
The first 18 subjects were used for training (2760 training, 552
validation slices) and the last 2 subjects (364 slices) were used for
testing. For each acceleration factor, our MTL network was trained
using multi-objective optimization and with
fixed $$$\lambda=0,0.025,0.05,0.2,0.5,0.8,0.95,0.975,1$$$ (referred to as "grid search"). Note that $$$\lambda=0$$$ and $$$\lambda=1$$$ correspond to detection only
and estimation only learning, respectively. To better understand the
trade-off between the two individual tasks, we also trained networks for
one of the tasks and then finetuned them on the other task for 16
epochs with a small learning rate (1e-5). Normalized root
mean-squared error and average Dice were used as estimation and
detection metrics, respectively.Results and Discussion
Fig.2 shows the Estimation and Detection Information Trade-off
(EDIT12) plot at four AFs. The figure includes
results obtained using the estimation/detection only approaches, the
conventional pipeline, as well as the grid search (fixed $$$\lambda$$$) and MOO techniques. Note that the
MTL approach, with either grid search or MOO techniques, leads to large
improvements in detection performance over estimation/detection only approaches
as well as the conventional sequential approach. Fig.3 shows the corresponding
estimation and detection results on a representative slice. On the detection
task, MTL networks (grid search or MOO) present a more accurate lesion
prediction than the single task networks as indicated by the Dice metrics. On
the estimation task, the three methods provide similar reconstruction quality at
the same AF. Fig.4 shows the EDIT plot for AF=8 with finetuning of single-task
networks. It can be observed that the detection performance of the
estimation only network is improved through finetuning and can even surpass the
performance of the detection only network. Fig.5 illustrates the evolution of
detection and estimation results when finetuning single-task networks. As expected, finetuning using a detection loss
on the estimation only network significantly improves detection performance.
Similarly, finetuning using an estimation loss on the detection only network
improves the visual quality of the reconstructed image.

Conclusion
We presented a multi-task learning framework, which outperforms the
conventional sequential processing pipeline. We also demonstrated that the
multi-objective optimization is an effective, automated approach to balance
the trade-off between the task losses. The
results suggest that accounting for the subsequent detection task during
image reconstruction may lead to enhanced detection performance.

Acknowledgements
The authors would like to acknowledge support from Arizona Health
Sciences Center Translational Imaging Program Project Stimulus, BIO5 Team
Scholar's Program, and Technology and Research Initiative Fund (TRIF)
Improving Health Initiative.

References
1. Caruana R. Multitask Learning: A Knowledge-Based Source of Inductive Bias. In: Machine Learning Proceedings 1993. Elsevier; 1993:41-48.
2. Cipolla R, Gal Y, Kendall A. Multi-task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics. In: 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition. Salt Lake City, UT: IEEE; 2018:7482-7491.
3. Chen Z, Badrinarayanan V, Lee C-Y, Rabinovich A. GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks. 2018.
4. Sener O, Koltun V. Multi-Task Learning as Multi-Objective Optimization. In: Bengio S, Wallach H, Larochelle H, Grauman K, Cesa-Bianchi N, Garnett R, eds. Advances in Neural Information Processing Systems 31. Curran Associates, Inc.; 2018:527-538.
5. Désidéri J-A. Multiple-gradient descent algorithm (MGDA) for multiobjective optimization. Comptes Rendus Mathematique. 2012;350(5-6):313-318.
6. Guo M, Haque A, Huang D-A, Yeung S, Fei-Fei L. Dynamic Task Prioritization for Multitask Learning. In: Ferrari V, Hebert M, Sminchisescu C, Weiss Y, eds. Computer Vision – ECCV 2018. Vol 11220. Cham: Springer International Publishing; 2018:282-299.
7. Lim B, Son S, Kim H, Nah S, Lee KM. Enhanced Deep Residual Networks for Single Image Super-Resolution. IEEE; 2017.
8. Yu F, Koltun V, Funkhouser T. Dilated Residual Networks. In: 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). Honolulu, HI: IEEE; 2017:636-644.
9. Collins DL, Zijdenbos AP, Kollokian V, et al. Design and construction of a realistic digital brain phantom. IEEE Transactions on Medical Imaging. 1998;17(3):463-468.
10. Pieper S, Halle M, Kikinis R. 3D Slicer. 2004:632-635.
11. Warfield SK, Zou KH, Wells WM. Simultaneous Truth and Performance Level Estimation (STAPLE): An Algorithm for the Validation of Image Segmentation. IEEE Trans Med Imaging. 2004;23(7):903-921.
12. Cushing JB, Clarkson EW, Mandava S, Bilgin A. Estimation and detection information trade-off for x-ray system optimization. In: Ashok A, Neifeld MA, Gehm ME, eds. Baltimore, MD, United States; 2016:98470U.