J-RAS: Enhancing Medical Image Segmentation via Retrieval-Augmented Joint Training
Authors
Abstract
Image segmentation, the process of dividing images into meaningful regions, is critical in medical applications for accurate diagnosis, treatment planning, and disease monitoring. Although manual segmentation by healthcare professionals produces precise outcomes, it is time-consuming, costly, and prone to variability due to differences in human expertise. Artificial intelligence (AI)-based methods have been developed to address these limitations by automating segmentation tasks; however, they often require large annotated datasets that are rarely available in practice, and they frequently struggle to generalize across diverse imaging conditions because of inter-patient variability and rare pathological cases. In this paper, we propose Joint Retrieval-Augmented Segmentation (J-RAS), a joint training method for guided image segmentation that integrates a segmentation model with a retrieval model. The two models are optimized jointly: the segmentation model leverages retrieved image-mask pairs to enrich its anatomical understanding, while the retrieval model learns segmentation-relevant features beyond simple visual similarity. This joint optimization ensures that retrieval contributes meaningful contextual cues to guide boundary delineation, thereby enhancing overall segmentation performance. We validate J-RAS across multiple segmentation backbones, including U-Net, TransUNet, SAM, and SegFormer, on two benchmark datasets, ACDC and M&Ms, and observe consistent improvements. For example, on the ACDC dataset, SegFormer without J-RAS achieves a mean Dice score of 0.8708$\pm$0.042 and a mean Hausdorff Distance (HD) of 1.8130$\pm$2.49, whereas with J-RAS it achieves a mean Dice score of 0.9115$\pm$0.031 and a mean HD of 1.1489$\pm$0.30. These results highlight the method's effectiveness and its generalizability across architectures and datasets.
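For readers unfamiliar with the two reported metrics, the following is a minimal sketch of how a Dice score and a Hausdorff Distance can be computed for a pair of binary masks. It is illustrative only and is not the evaluation code used for the results above: the function names (`dice_score`, `foreground_points`, `hausdorff_distance`), the toy masks, and the handling of empty masks are all assumptions, and the abstract does not state which HD variant (for example, full-mask versus boundary-based, or pixel versus millimetre units) was used.

```python
import math


def dice_score(pred, target):
    """Dice coefficient between two binary masks given as nested lists of 0/1."""
    p = [v for row in pred for v in row]
    t = [v for row in target for v in row]
    intersection = sum(1 for a, b in zip(p, t) if a and b)
    total = sum(1 for a in p if a) + sum(1 for b in t if b)
    # Assumed convention: two empty masks count as a perfect match.
    return 1.0 if total == 0 else 2.0 * intersection / total


def foreground_points(mask):
    """Return (row, col) coordinates of all non-zero pixels."""
    return [(r, c) for r, row in enumerate(mask) for c, v in enumerate(row) if v]


def hausdorff_distance(pred, target):
    """Symmetric Hausdorff distance, in pixels, between foreground pixel sets."""
    a, b = foreground_points(pred), foreground_points(target)
    if not a or not b:
        # Undefined when either mask is empty; returning inf is an assumption.
        return math.inf

    def directed(src, dst):
        # Largest distance from any point in src to its nearest point in dst.
        return max(min(math.dist(s, d) for d in dst) for s in src)

    return max(directed(a, b), directed(b, a))


# Toy example: the prediction misses one foreground pixel of the target.
pred = [[0, 1, 1, 0],
        [0, 1, 1, 0],
        [0, 0, 0, 0]]
target = [[0, 1, 1, 0],
          [0, 1, 1, 1],
          [0, 0, 0, 0]]

print(dice_score(pred, target))
print(hausdorff_distance(pred, target))
```

Higher Dice indicates better overlap, while lower HD indicates that the worst-case disagreement between the two masks is smaller, which is why the abstract reports an increase in the former and a decrease in the latter.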