Most deep-learning approaches require defining a loss function that is appropriate for the task. The choice of loss function generally has a substantial effect on the accuracy of the trained model and often requires hand-tuning. For example, some segmentation tasks work well with the Dice loss while others work well with mean squared error (MSE). In this work we show how a conditional adversarial network (cGAN) can be used to avoid defining a specialized loss function for each task and, instead, to achieve comparable or even superior results with a simple loss in the context of MRI image segmentation.
Several studies propose hand-crafted loss functions for a particular task, arguing that experimental observations show superior performance of the proposed loss compared to simpler ones1. This leads to task-specific hand-crafted loss functions being required to achieve good results. In this work, we investigate whether there is a task-agnostic alternative to the loss function that can achieve results as good as a hand-crafted one.
In this abstract, we target the specific task of MR image segmentation and its accompanying loss function. For this task, many prior works converge on the U-net2 as the network architecture and the Dice3 coefficient and its variations as the loss function. The choice of Dice over other loss functions has been shown to be superior for image segmentation, particularly on tasks with class imbalance, where the marked pixels in the positive (foreground) class are far fewer than the non-marked pixels in the negative (background) class1. The question that remains is whether this is really the best loss function available for training. Indeed, subsequent works have proposed minor tweaks to the calculation of the Dice coefficient, as well as other loss functions, to improve segmentation results.
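For reference, the Dice loss used for training is typically a differentiable "soft" relaxation of the Dice coefficient. The sketch below is a generic PyTorch formulation; the framework and the exact variant used in the cited works are not specified here, so treat this only as an illustration.

```python
import torch

def soft_dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss: 1 - 2|P∩G| / (|P| + |G|).

    pred:   predicted probabilities in [0, 1], shape (N, 1, D, H, W)
    target: binary ground-truth mask with the same shape
    """
    dims = tuple(range(1, pred.ndim))             # sum over all but the batch dimension
    intersection = (pred * target).sum(dim=dims)
    denom = pred.sum(dim=dims) + target.sum(dim=dims)
    dice = (2.0 * intersection + eps) / (denom + eps)
    return 1.0 - dice.mean()
```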
The method of this work is to use conditional generative adversarial networks4,6, also known as cGANs, to replace the specialized loss function with an additional network, called the discriminator, appended after the segmentation network, while using only simple loss functions. The generator in the cGAN performs the segmentation and its output is labelled as the fake input for the discriminator, whereas the ground-truth segmentation mask is labelled as the valid input. Additionally, as shown in Fig.1, the discriminator is conditioned on the input MRI volume in both cases; that is, the discriminator takes two inputs: a segmentation mask and the MRI image that conditions it (the same image given to the generator). The generator uses the U-net architecture and MSE (mean squared error) as the loss function. The discriminator uses a convolutional architecture similar to VGG5, with a final fully connected layer and binary cross-entropy as the loss function. The training procedure alternates between discriminator and generator updates.
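The following sketch illustrates one way such interleaved training could be implemented in PyTorch. The names (generator, discriminator, adv_weight) are hypothetical, and combining the MSE term with an adversarial term for the generator is an assumption, since the abstract does not state how the two losses are weighted.

```python
import torch
import torch.nn.functional as F

def train_step(generator, discriminator, g_opt, d_opt, mri, gt_mask, adv_weight=1.0):
    """One interleaved update; discriminator is assumed to output one logit per volume."""
    valid = torch.ones(mri.size(0), 1, device=mri.device)
    fake = torch.zeros(mri.size(0), 1, device=mri.device)

    # Discriminator update: real pair = (ground-truth mask, MRI), fake pair = (predicted mask, MRI)
    with torch.no_grad():
        pred_mask = generator(mri)
    d_real = discriminator(torch.cat([gt_mask, mri], dim=1))
    d_fake = discriminator(torch.cat([pred_mask, mri], dim=1))
    d_loss = (F.binary_cross_entropy_with_logits(d_real, valid)
              + F.binary_cross_entropy_with_logits(d_fake, fake))
    d_opt.zero_grad()
    d_loss.backward()
    d_opt.step()

    # Generator update: simple MSE to the ground truth plus the adversarial term
    pred_mask = generator(mri)
    g_adv = F.binary_cross_entropy_with_logits(
        discriminator(torch.cat([pred_mask, mri], dim=1)), valid)
    g_loss = F.mse_loss(pred_mask, gt_mask) + adv_weight * g_adv
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()

    return d_loss.item(), g_loss.item()
```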
We used T2-weighted brain MRI images from 847 clinical exams for training. All studies were approved by an appropriate IRB. Atlas registration was used to generate ground-truth segmentations. We augmented the data 25-fold for training. All testing was performed on an independent data set of 98 exams.
To provide a comparison for the cGAN trained on the segmentation task, we trained a single U-net for segmentation with the same architecture and hyper-parameters as the U-net used in the cGAN's generator. This single U-net does not rely on a discriminator network to further improve its segmentation, and it is trained with the MSE loss in one instance and the Dice loss in another. Both instances are used for comparison in the results.
The three networks, cGAN with MSE, single U-net with MSE, and single U-net with Dice, were trained to segment the hippocampus in brain MRI volumes. The results (Fig.2) show that the cGAN trained with MSE achieves a Dice segmentation accuracy close to that of the single U-net trained with Dice: the cGAN reaches 0.9022 Dice while the U-net reaches 0.9070. The single U-net trained with MSE achieved 0.8675 Dice, worse than the other two.
In addition to the Dice metric, the cGAN improves the mean absolute distance (MAD) error between the predicted and ground-truth segmentation masks, achieving 0.03 ± 0.10 millimeters (mm) compared to 0.05 ± 0.18 mm for the single U-net. The single U-nets trained with MSE and Dice achieved similar MAD errors.
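One common way to compute such a surface-distance metric is as a symmetric mean distance between the boundaries of the two binary masks, as in the SciPy-based sketch below. The voxel spacing argument and the exact symmetrization are assumptions; the precise MAD definition used in the experiments may differ in detail.

```python
import numpy as np
from scipy.ndimage import binary_erosion, distance_transform_edt

def mean_absolute_distance(pred_mask, gt_mask, spacing_mm=(1.0, 1.0, 1.0)):
    """Symmetric mean surface distance (mm) between two binary 3D masks."""
    def surface(mask):
        # Boundary voxels: the mask minus its erosion
        return mask & ~binary_erosion(mask)

    pred_surf = surface(pred_mask.astype(bool))
    gt_surf = surface(gt_mask.astype(bool))

    # Distance (in mm) from every voxel to the nearest surface voxel of each mask
    dist_to_gt = distance_transform_edt(~gt_surf, sampling=spacing_mm)
    dist_to_pred = distance_transform_edt(~pred_surf, sampling=spacing_mm)

    d1 = dist_to_gt[pred_surf]   # prediction surface -> ground-truth surface
    d2 = dist_to_pred[gt_surf]   # ground-truth surface -> prediction surface
    return np.concatenate([d1, d2]).mean()
```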
The results show that the extra discriminator network in the cGAN is able to improve metrics that are exogenous to the loss it is trained with, whereas the single U-net with Dice performs well on the Dice metric, as expected, while the U-net with MSE does not. On the MAD error metric, which is exogenous to all of the networks, the cGAN fits its predicted mask considerably better than both U-nets.
Of course, this help from the extra discriminator comes at a price: training the cGAN is roughly 3× slower than training the U-net with Dice, which is in turn somewhat slower than training the U-net with MSE.
1. Salehi S.S.M., Erdogmus D., Gholipour A. Tversky loss function for image segmentation using 3D fully convolutional deep networks. Machine Learning in Medical Imaging (MLMI 2017), Lecture Notes in Computer Science, vol 10541, 2017.
2. Ronneberger O., Fischer P., Brox T. U-Net: Convolutional networks for biomedical image segmentation. Medical Image Computing and Computer-Assisted Intervention (MICCAI 2015), Lecture Notes in Computer Science, vol 9351: 234-241, 2015.
3. Milletari F., Navab N., Ahmadi S.-A. V-Net: Fully convolutional neural networks for volumetric medical image segmentation. Fourth International Conference on 3D Vision (3DV), 565-571, 2016.
4. Isola P., Zhu J.-Y., Zhou T., Efros A.A. Image-to-image translation with conditional adversarial networks. CVPR, 2017.
5. Simonyan K., Zisserman A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
6. Shin H.C. et al. Medical image synthesis for data augmentation and anonymization using generative adversarial networks. Simulation and Synthesis in Medical Imaging (SASHIMI 2018), Lecture Notes in Computer Science, vol 11037, 2018; arXiv:1807.10225.