Abstract
In this white paper, we present an AI-based approach to brain MRI segmentation using the U-Net architecture. The model identifies regions of interest in MRI scans, supporting diagnosis in medical imaging. This paper outlines the design, architecture, and implementation of the model and shows how it can be integrated into real-world applications through a patch-based inference pipeline, a Gradio web interface, and a FastAPI deployment.
Introduction
Medical imaging plays a crucial role in diagnosing diseases, particularly brain abnormalities. Manual segmentation of MRI scans is time-consuming and prone to errors. AI-driven segmentation models such as U-Net have proven effective at automating this process with high accuracy. This paper introduces a machine-learning solution that uses U-Net for brain MRI segmentation.
Problem Statement: Manual segmentation of MRI scans is labor-intensive and error-prone, necessitating an automated approach.
Objective: Develop a robust and efficient AI model to segment brain MRI scans accurately.
Methodology
Dataset Preprocessing
The input MRI images are preprocessed to ensure uniformity and improve segmentation accuracy. Specifically, each image is resized to a standard resolution of 256x256 pixels to match the input size expected by the U-Net model.
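# Resize an MRI image to the 256x256 resolution expected by the model
from PIL import Image

TARGET_SIZE = (256, 256)  # (width, height)

def resize_image(input_image_path, target_size=TARGET_SIZE):
    img = Image.open(input_image_path)
    resized_img = img.resize(target_size)
    # Note: this overwrites the original file in place and returns its path
    resized_img.save(input_image_path)
    return input_image_path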
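The U-Net in the next subsection is loaded with `in_channels=3`, so a resized slice must also be arranged as a three-channel batch before inference. The following is a minimal NumPy sketch of that step. The function name `to_model_input`, the replication of grayscale slices into three channels, and the per-image min-max scaling to [0, 1] are assumptions for illustration; they are not necessarily the normalization the pre-trained weights were trained with.

```python
import numpy as np


def to_model_input(img):
    """Convert an H x W or H x W x C image into a (1, 3, H, W) float32 batch.

    Assumptions (not from the original pipeline): grayscale input is replicated
    into three channels, an alpha channel is dropped, and intensities are
    min-max scaled to [0, 1] per image.
    """
    arr = np.asarray(img, dtype=np.float32)
    if arr.ndim == 2:                      # H x W -> H x W x 1
        arr = arr[..., np.newaxis]
    if arr.shape[-1] == 1:                 # replicate a single channel into three
        arr = np.repeat(arr, 3, axis=-1)
    arr = arr[..., :3]                     # drop an alpha channel if present
    lo, hi = arr.min(), arr.max()
    # Min-max scaling; a constant image maps to all zeros to avoid dividing by zero
    arr = (arr - lo) / (hi - lo) if hi > lo else np.zeros_like(arr)
    return arr.transpose(2, 0, 1)[np.newaxis]  # HWC -> CHW, then add a batch axis


batch = to_model_input(np.zeros((256, 256), dtype=np.uint8))
print(batch.shape)  # (1, 3, 256, 256)
```

In a PyTorch pipeline the resulting array would then be wrapped with `torch.from_numpy(batch)` before being passed to the model.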
U-Net Model for Brain Segmentation
The U-Net architecture (Ronneberger et al., 2015) was chosen for this project because of its demonstrated effectiveness in biomedical image segmentation. Its main components are:
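Encoder-Decoder Structure: The encoder (contracting path) extracts features at progressively coarser scales, and the decoder (expanding path) upsamples them back to the input resolution to produce a pixel-wise segmentation.
Skip Connections: Encoder feature maps are concatenated with the corresponding decoder feature maps, preserving fine spatial detail that would otherwise be lost during downsampling and improving segmentation accuracy.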
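To make the encoder-decoder structure and the skip connections concrete, the sketch below tracks feature-map sizes for a 256x256 input with `init_features=32`, the value used in the loading code. It assumes the standard four-level U-Net layout, in which the channel count doubles from one level to the next while 2x2 max pooling halves the spatial size; the depth of the hub model is not stated in this paper, so treat the numbers as illustrative. The function name `unet_shapes` is hypothetical.

```python
def unet_shapes(size=256, init_features=32, levels=4):
    """Return (spatial size, channels) per encoder level, the bottleneck, and decoder inputs.

    Assumes the common U-Net layout: channels double per level and 2x2 max pooling
    halves the spatial size, so `size` must be divisible by 2**levels.
    """
    if size % (2 ** levels):
        raise ValueError(f"input size {size} is not divisible by {2 ** levels}")
    encoder = [(size // 2 ** i, init_features * 2 ** i) for i in range(levels)]
    bottleneck = (size // 2 ** levels, init_features * 2 ** levels)
    # Each decoder level concatenates the upsampled features (C channels) with the
    # matching encoder skip connection (C channels), so its input has 2 * C channels.
    decoder_inputs = [(s, 2 * c) for s, c in reversed(encoder)]
    return encoder, bottleneck, decoder_inputs


encoder, bottleneck, decoder_inputs = unet_shapes()
print("encoder:", encoder)
print("bottleneck:", bottleneck)
print("decoder inputs after concatenation:", decoder_inputs)
```

Under these assumptions a 256x256 slice is reduced to a 16x16 bottleneck with 512 channels, which is also why input sides should be multiples of 16.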
The model is pre-trained on brain MRI data, and its weights are loaded from PyTorch Hub:
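import torch

# Load the pre-trained U-Net: 3-channel input, 1-channel mask output
model = torch.hub.load(
    'arithescientisttt/unet_brainsegmentation',
    'unet',
    in_channels=3,
    out_channels=1,
    init_features=32,
    pretrained=True,
)
model.eval()  # put layers such as batch normalization into inference mode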
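With `out_channels=1`, the network produces a single-channel map with the same height and width as its input. Assuming the final layer applies a sigmoid (an assumption about this repository, not confirmed in this paper), each value is a per-pixel foreground probability, and the displayed mask is obtained by thresholding. In PyTorch the output would first be converted with `output.detach().cpu().numpy()`; the NumPy sketch below shows the thresholding itself. The function name and the 0.5 default are illustrative choices, not tuned values.

```python
import numpy as np


def probabilities_to_mask(probs, threshold=0.5):
    """Turn a (1, 1, H, W) or (H, W) probability map into a binary uint8 mask.

    Assumes values lie in [0, 1] (sigmoid output); 0.5 is a conventional default.
    """
    probs = np.squeeze(np.asarray(probs))  # drop singleton batch and channel axes
    if probs.ndim != 2:
        raise ValueError(f"expected a single 2-D map, got shape {probs.shape}")
    return (probs >= threshold).astype(np.uint8)


# Stand-in for a model output of shape (1, 1, 2, 2)
out = np.array([[[[0.1, 0.7], [0.5, 0.2]]]])
print(probabilities_to_mask(out))
```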
Image Processing and Inference
After preprocessing, we segment each image by splitting it into patches, passing each patch through the model, and aggregating the patch predictions into a full-size mask.
Grid Sampler and Aggregator: TorchIO's GridSampler splits each image into (optionally overlapping) patches, and GridAggregator combines the per-patch predictions into the final segmentation map.
Inference Process: Each image patch is processed by the U-Net model to predict the segmented regions.
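# Split the preprocessed subject into patches for inference
import torchio as tio

# subject_preprocessed (a tio.Subject), patch_size, and patch_overlap are defined earlier in the pipeline
grid_sampler = tio.inference.GridSampler(subject_preprocessed, patch_size, patch_overlap)
# The aggregator later stitches the per-patch predictions back into a full-size map
aggregator = tio.inference.GridAggregator(grid_sampler)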
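TorchIO handles the patch bookkeeping internally: each batch of patches is passed to the model, the outputs are returned with `aggregator.add_batch(outputs, locations)`, and the stitched result is read with `aggregator.get_output_tensor()`. To show what sampling and aggregation mean, the following is a from-scratch NumPy illustration for a single 2-D slice; it is not TorchIO's implementation. It averages overlapping predictions, similar in spirit to GridAggregator's `overlap_mode='average'` (the default mode is `'crop'`). The function names, the patch size of 128, and the overlap of 32 are hypothetical example values, and `predict` stands in for the model.

```python
import numpy as np


def patch_starts(length, patch, overlap):
    """Start indices of patches of size `patch` that cover [0, length) with `overlap`."""
    if not 0 <= overlap < patch <= length:
        raise ValueError("need 0 <= overlap < patch <= length")
    step = patch - overlap
    starts = list(range(0, length - patch + 1, step))
    if starts[-1] + patch < length:  # add a final patch flush with the far edge
        starts.append(length - patch)
    return starts


def grid_predict(image, predict, patch=128, overlap=32):
    """Apply `predict` to overlapping 2-D patches and average where patches overlap."""
    h, w = image.shape
    total = np.zeros((h, w), dtype=np.float64)   # running sum of patch predictions
    counts = np.zeros((h, w), dtype=np.float64)  # number of patches covering each pixel
    for y in patch_starts(h, patch, overlap):
        for x in patch_starts(w, patch, overlap):
            tile = image[y:y + patch, x:x + patch]
            total[y:y + patch, x:x + patch] += predict(tile)
            counts[y:y + patch, x:x + patch] += 1
    return total / counts  # every pixel is covered at least once


# Usage with a stand-in "model" that returns its input unchanged
slice_2d = np.random.default_rng(0).random((256, 256))
reconstructed = grid_predict(slice_2d, predict=lambda tile: tile)
print(np.allclose(reconstructed, slice_2d))
```

With an identity `predict`, the aggregated output equals the input, which is a quick check that every pixel is covered and overlaps are weighted correctly.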
Results and Evaluation
The output from the U-Net model is a segmentation mask, which is visualized by overlaying it on the original MRI scan. This allows for easy interpretation of the segmented areas.
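# Overlay the predicted segmentation mask on the input MRI image
import matplotlib.pyplot as plt

# Size the figure so one image pixel maps to one screen pixel at 100 dpi
fig, ax = plt.subplots(figsize=(images.shape[1] / 100, images.shape[0] / 100), dpi=100)
ax.imshow(images)
ax.imshow(mask, cmap='gray', alpha=0.5)  # semi-transparent mask on top of the scan
ax.axis('off')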
We tested the model on several MRI images, and qualitative inspection of the overlaid masks indicated that it segments the target brain regions effectively. Accuracy can be quantified with the Dice similarity coefficient or the Jaccard index; this quantitative evaluation is left to future work.
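For reference, both metrics compare a predicted binary mask A with a reference mask B: Dice = 2|A ∩ B| / (|A| + |B|) and Jaccard = |A ∩ B| / |A ∪ B|. The minimal NumPy sketch below computes them; the function names are illustrative, and scoring two empty masks as 1.0 is a convention we assume here rather than a universal rule.

```python
import numpy as np


def dice_coefficient(pred, target):
    """Dice = 2|A ∩ B| / (|A| + |B|) for binary masks A (prediction) and B (reference)."""
    a, b = np.asarray(pred, dtype=bool), np.asarray(target, dtype=bool)
    denom = a.sum() + b.sum()
    # Convention (an assumption): two empty masks count as perfect agreement
    return 1.0 if denom == 0 else float(2.0 * np.logical_and(a, b).sum() / denom)


def jaccard_index(pred, target):
    """Jaccard (IoU) = |A ∩ B| / |A ∪ B| for binary masks A and B."""
    a, b = np.asarray(pred, dtype=bool), np.asarray(target, dtype=bool)
    union = np.logical_or(a, b).sum()
    return 1.0 if union == 0 else float(np.logical_and(a, b).sum() / union)


predicted = np.array([[1, 1], [0, 0]])
reference = np.array([[1, 0], [0, 0]])
print(dice_coefficient(predicted, reference), jaccard_index(predicted, reference))
```

For any single pair of masks the two scores are related by Dice = 2J / (1 + J), so they rank predictions identically; Dice is simply the more lenient number.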
Integration and Application
Gradio-Based Interface
To allow easy interaction with the segmentation model, we built a web interface with Gradio. Users upload an MRI image, and the system returns a plot of the image with the predicted mask overlaid. Gradio was chosen for its simplicity and its ability to give users immediate feedback.
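import gradio as gr

# inference(filepath) returns the matplotlib overlay figure; description, title, and examples are defined elsewhere
interface = gr.Interface(inference, gr.Image(label="input image", type="filepath"), gr.Plot(),
                         description=description, title=title, examples=examples)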
FastAPI Deployment
The entire solution can be deployed as a web application by mounting the Gradio interface inside a FastAPI app. FastAPI serves it as a standard web service alongside any other API routes, which makes the solution suitable for cloud deployment or for integration into a hospital's imaging workflow.
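from fastapi import FastAPI
app = FastAPI()
# Serve the Gradio UI at "/"; run with `uvicorn main:app` if this file is main.py
app = gr.mount_gradio_app(app, interface, path="/")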
Conclusion and Future Work
This project illustrates the potential of AI-based approaches, specifically U-Net, for automating brain MRI segmentation. Accurately identifying regions of interest in MRI scans has significant implications for medical diagnostics. Future work will focus on quantifying and improving the model’s accuracy, adding training data, and extending the application to other types of medical imaging.
References
Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. arXiv preprint arXiv:1505.04597.
PyTorch Hub documentation. https://pytorch.org/hub/
TorchIO documentation. https://torchio.readthedocs.io/