Decoding X-ray Hand Bones | Semantic Segmentation

1. Introduction

Bone segmentation is an important application of artificial intelligence in medical imaging: deep learning models locate and segment each individual bone.

Thanks to boostcamp AI Tech, our team was able to work on a project that segments and analyzes hand bones with semantic segmentation models, using a medical image dataset of a kind that is generally hard to obtain.

As team leader on this project, I managed the project schedule, built a visualization tool for analyzing the dataset and experiment results, ran model experiments with the relevant libraries, and implemented the ensemble that improved our final results.


2. Exploratory Data Analysis (EDA)

2.1. Visualization Tool

We built a visualization tool with Streamlit, an open-source framework for AI/ML engineers. To analyze the image data properly, we needed the following features.

  • Compare (image vs label, image vs model result, model result vs label)
  • Toggle the visibility of labels by class
  • Upload experiment results (in CSV format)

With all of the requirements above in mind, I made the following web demo for visualizing the data (images and labels). You can find the code here.

streamlit Visualization Web Demo
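The per-class visibility toggle can be illustrated without any UI code. The following is a minimal sketch, not the actual tool: `overlay_masks`, `visible`, and `palette` are hypothetical names, and the masks are assumed to be binary arrays with the same height and width as the image. In a Streamlit app, the `visible` set could be driven by one checkbox per class.

```python
import numpy as np


def overlay_masks(image, masks, visible, palette, alpha=0.5):
    """Blend the visible class masks onto a grayscale image.

    image:   (H, W) uint8 grayscale X-ray
    masks:   dict of class name -> (H, W) binary mask
    visible: set of class names to draw (the per-class toggle)
    palette: dict of class name -> (R, G, B) color
    """
    # Promote grayscale to RGB so colored masks can be blended in.
    canvas = np.stack([image] * 3, axis=-1).astype(np.float32)
    for name, mask in masks.items():
        if name not in visible:
            continue  # class is toggled off
        color = np.asarray(palette[name], dtype=np.float32)
        region = mask.astype(bool)
        canvas[region] = (1 - alpha) * canvas[region] + alpha * color
    return canvas.round().astype(np.uint8)


# Toy data: two 2x2 "bones" in opposite corners of a 4x4 image.
image = np.full((4, 4), 100, dtype=np.uint8)
masks = {
    "Lunate": np.pad(np.ones((2, 2), dtype=np.uint8), ((0, 2), (0, 2))),
    "Pisiform": np.pad(np.ones((2, 2), dtype=np.uint8), ((2, 0), (2, 0))),
}
palette = {"Lunate": (255, 0, 0), "Pisiform": (0, 255, 0)}

# Only Lunate is toggled on, so the Pisiform corner stays untouched.
shown = overlay_masks(image, masks, visible={"Lunate"}, palette=palette)
```

The same function covers all three comparisons in the feature list if the label masks and the model's predicted masks are passed in separately and rendered side by side.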

2.2. What we found

Our web visualizer revealed a few things to keep in mind for this project. First, the wrist bones overlap one another in the image, unlike the finger bones, so a model that works well on medical images would be a good fit. Next, some bones, such as those in the fingertips and wrist, are too small to identify with the naked eye; enlarging the image could help. Finally, the small amount of data limits how far we can improve model performance.


3. Model Experiments

3.1. Segmentation models

Segmentation is a relatively mature deep learning task, even though AI as a field changes quickly. Fortunately, the Segmentation Models library supports a wide range of segmentation architectures and encoders, so we adopted it to run experiments quickly. Our interface for building a model looks as follows.

import segmentation_models_pytorch as smp


class SmpModel:
    """Thin wrapper that builds a segmentation_models_pytorch model from a config dict."""

    def __init__(self, model_config, num_classes):
        self.arch = model_config['arch']
        self.encoder_name = model_config['encoder_name']
        self.encoder_weights = model_config['encoder_weights']
        self.in_channels = model_config['in_channels']
        self.num_classes = num_classes

    def get_model(self):
        # smp.create_model looks up the architecture by name and attaches the encoder.
        self.model = smp.create_model(
            arch=self.arch,
            encoder_name=self.encoder_name,
            encoder_weights=self.encoder_weights,
            in_channels=self.in_channels,
            classes=self.num_classes,
        )

        return self.model
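For illustration, a config passed to the class above might look like the sketch below. The key names are the ones `SmpModel.__init__` reads; the values are assumptions, not the settings we actually used, and `validate_model_config` is a hypothetical helper that fails early when a key is missing.

```python
# Hypothetical config: the keys match what SmpModel.__init__ reads,
# the values are assumptions chosen for illustration only.
model_config = {
    "arch": "unetplusplus",         # assumed smp architecture name for UNet++
    "encoder_name": "resnet50",     # backbone
    "encoder_weights": "imagenet",  # assumed pretrained weights
    "in_channels": 3,               # assumed: X-rays loaded as 3-channel images
}

REQUIRED_KEYS = {"arch", "encoder_name", "encoder_weights", "in_channels"}


def validate_model_config(config):
    """Raise before model construction if a required key is missing."""
    missing = REQUIRED_KEYS - set(config)
    if missing:
        raise KeyError(f"model_config is missing: {sorted(missing)}")
    return config


validate_model_config(model_config)
```

A validated config could then be passed as `SmpModel(model_config, num_classes).get_model()`, so swapping the architecture or backbone only requires changing the dict.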

Given the characteristics of our dataset, I experimented with models that originated in medical imaging (UNet, MANet) and with more recent models (UPerNet, SegFormer).

  • UNet Series

| Model | Backbone | Dice |
|---|---|---|
| UNet | ResNet50 | 0.9474 |
| UNet++ | ResNet50 | 0.9538 |
| UNet++ | ResNet101 | 0.9517 |
| UNet++ | EfficientNet-b5 | 0.9511 |
| UNet++ | GerNet-L | 0.9513 |

  • MANet: Due to limited time, these models were trained with a smaller image size (512 ➡️ 256) and fewer epochs (100 ➡️ 50).

| Model | Backbone | Dice |
|---|---|---|
| UNet++ | ResNet50 | 0.9101 |
| MANet | ResNet50 | 0.8989 |
| MANet | Res2Net50_26w_4s | 0.8957 |
| MANet | ResNest50d_4s2x40d | 0.9026 |

  • UPerNet

| Model | Backbone | Dice |
|---|---|---|
| UPerNet | ResNet101 | 0.9479 |
| UPerNet | Swin Transformer | 0.9501 |

  • SegFormer

| Model | Backbone | Dice |
|---|---|---|
| SegFormer | mit-b0 | 0.9610 |
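
The tables above report the Dice coefficient. As a reference, here is a minimal sketch of Dice for a single pair of binary masks, Dice = 2|A ∩ B| / (|A| + |B|). The function name and the small `eps` smoothing term are assumptions; how the scores were averaged over classes and images in our evaluation is not shown here.

```python
import numpy as np


def dice_coefficient(pred, target, eps=1e-4):
    """Dice = 2 * |pred AND target| / (|pred| + |target|) for binary masks."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    # eps keeps the ratio defined when both masks are empty.
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)


pred = np.array([[1, 1, 0], [0, 1, 0]])
target = np.array([[1, 0, 0], [0, 1, 1]])

# Two pixels overlap and each mask has three pixels, so Dice is about 4 / 6.
score = dice_coefficient(pred, target)
```

Because the score is a ratio of overlap to total mask area, a few misclassified pixels cost much more on a small bone than on a large one, which is one reason small bones are harder to score well on.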

3.2. Augmentation

We also performed offline augmentation, cropping and enlarging the small bones and the wrist region, to address the small-bone and overlap problems.

  • Size

As the table below shows, performance improved significantly as the image size increased. However, we could not enlarge the images beyond 1024 because of memory limitations.

| Image Size | Dice |
|---|---|
| 256 x 256 | 0.8431 |
| 512 x 512 | 0.8575 |
| 1024 x 1024 | 0.9644 |

  • Off-line augmentation

By cropping the wrist and/or fingertip regions, we added more images to our dataset. Performance improved, which seems to be due to 1) the larger amount of data itself and 2) additional training on the regions the model had not learned well.

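| Class | Original | Wrist | Fingertip | Wrist + Fingertip |
|---|---|---|---|---|
| Lunate | 0.9103 | 0.9366 | 0.9340 | 0.9316 |
| Trapezoid | 0.8705 | 0.8951 | 0.8917 | 0.9005 |
| Pisiform | 0.8232 | 0.8750 | 0.8689 | 0.8864 |
| Trapezium | 0.9110 | 0.9221 | 0.9082 | 0.9178 |
| finger-16 | 0.8564 | 0.9060 | 0.9054 | 0.9002 |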
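
The cropping step behind the wrist and fingertip images can be sketched as follows. This is an illustration under assumptions rather than our actual script: `crop_around_classes` and `margin` are hypothetical names, the masks are assumed to be per-class binary arrays, and the resize back to the training resolution (the "enlarging" part) is left out.

```python
import numpy as np


def crop_around_classes(image, masks, class_names, margin=16):
    """Crop the image and all masks to the bounding box of the selected classes."""
    union = np.zeros(image.shape[:2], dtype=bool)
    for name in class_names:
        union |= masks[name].astype(bool)
    if not union.any():
        return image, masks  # nothing to crop around

    rows = np.flatnonzero(union.any(axis=1))
    cols = np.flatnonzero(union.any(axis=0))
    height, width = union.shape
    # Expand the box by the margin, clamped to the image borders.
    top = max(rows[0] - margin, 0)
    bottom = min(rows[-1] + 1 + margin, height)
    left = max(cols[0] - margin, 0)
    right = min(cols[-1] + 1 + margin, width)

    cropped_image = image[top:bottom, left:right]
    cropped_masks = {k: m[top:bottom, left:right] for k, m in masks.items()}
    return cropped_image, cropped_masks


# Toy data: one small wrist bone and one finger bone far away from it.
image = np.zeros((100, 100), dtype=np.uint8)
masks = {
    "Pisiform": np.zeros((100, 100), dtype=np.uint8),
    "finger-16": np.zeros((100, 100), dtype=np.uint8),
}
masks["Pisiform"][40:50, 60:70] = 1
masks["finger-16"][5:10, 5:10] = 1

crop, crop_masks = crop_around_classes(image, masks, ["Pisiform"], margin=5)
```

Every mask is cropped with the same box, so bones that overlap the selected region keep their labels in the new training sample.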


4. Ensemble

In data science, ensemble learning is a technique that combines multiple models to make more accurate predictions. We used a hard-voting ensemble to improve our results: for each class, the per-pixel predictions of several models are summed, and a pixel is included in the final mask only when its vote count exceeds a threshold. My implementation is as follows.

import os
import numpy as np
import pandas as pd 
from tqdm import tqdm

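def decode_rle_to_mask(rle, height, width):
    # RLE is a space-separated list of "start length" pairs with 1-based starts.
    s = rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0::2], s[1::2])]
    starts -= 1  # convert to 0-based indices
    ends = starts + lengths
    img = np.zeros(height * width, dtype=np.uint8)

    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1

    return img.reshape(height, width)

def encode_mask_to_rle(mask):
    pixels = mask.flatten()
    # Pad with zeros so runs touching either end of the array are detected.
    pixels = np.concatenate([[0], pixels, [0]])
    # 1-based positions where the value changes: run starts and run ends alternate.
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]  # turn each run end into a run length

    return ' '.join(str(x) for x in runs)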

def csv_ensemble(csv_paths, save_dir, threshold): 

    csv_data = []
    for path in csv_paths:
        data = pd.read_csv(path)
        csv_data.append(data)

    file_num = len(csv_data)
    filename_and_class = []
    rles = []

    print(f"Model Number : {file_num}, threshold: {threshold}")  

    pivot_df = csv_data[0]

    for _, row  in tqdm(pivot_df.iterrows(), total=pivot_df.shape[0]):
        img = row['image_name']
        cls = row['class']
        model_rles = []

        for data in csv_data:
            rle_value = data[(data['image_name'] == img) & (data['class'] == cls)]['rle'].values
            if len(rle_value) == 0 or pd.isna(rle_value[0]): 
                print(f"No RLE Data : Image {img}, Class {cls}")
                model_rles.append(np.zeros((2048, 2048), dtype=np.uint8))  
                continue
            model_rles.append(decode_rle_to_mask(str(rle_value[0]), 2048, 2048))
     
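        # Sum the binary masks so each pixel holds the number of models voting for it.
        image = np.zeros((2048, 2048))

        for model in model_rles:
            image += model

        # Hard voting: keep a pixel only when its vote count exceeds the threshold.
        result_image = (image > threshold).astype(np.uint8)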

        rles.append(encode_mask_to_rle(result_image))
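        filename_and_class.append((cls, img))

    # Unpack the (class, image) pairs directly. Joining on "_" and splitting again
    # breaks as soon as a class or image name contains an underscore.
    classes, filename = zip(*filename_and_class)
    image_name = [os.path.basename(f) for f in filename]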

    df = pd.DataFrame({
        "image_name": image_name,
        "class": classes,
        "rle": rles,
    })

    df.to_csv(save_dir, index=False)
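
To make the threshold semantics concrete, here is a toy version of the voting step on three tiny masks. It mirrors the strict `>` comparison in `csv_ensemble`; `hard_vote` and the `model_*` arrays are hypothetical names used only for this illustration.

```python
import numpy as np


def hard_vote(masks, threshold):
    """Keep a pixel when strictly more than `threshold` models predict it."""
    votes = np.sum(np.stack(masks), axis=0)
    return (votes > threshold).astype(np.uint8)


model_a = np.array([[1, 1, 0], [0, 1, 0]], dtype=np.uint8)
model_b = np.array([[1, 0, 0], [0, 1, 1]], dtype=np.uint8)
model_c = np.array([[1, 1, 0], [0, 0, 1]], dtype=np.uint8)

# threshold=1 keeps pixels with at least two votes,
# threshold=2 keeps only pixels all three models agree on.
at_least_two = hard_vote([model_a, model_b, model_c], threshold=1)
unanimous = hard_vote([model_a, model_b, model_c], threshold=2)
```

Because the comparison is strict, a threshold of `t` means "at least `t + 1` votes". A lower threshold favors recall and a higher one favors precision, so it is worth comparing a couple of values against the validation score.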


5. Conclusion

5.1. Result

Based on the experiment results, we selected the following final models and trained each of them for 100 epochs.

  • UNet++: Backbone (ResNet50), Image Size (1024)
  • DeepLabV3: Backbone (EfficientNet-b8), Image Size (1024)
  • UPerNet: Backbone (Swin Transformer), Image Size (1536)
  • SegFormer: Backbone (mit-b0), Image Size (1024)

Ensembling these models gave our best performance, an improvement of 3.4% over the baseline.

| Model | Threshold | Dice |
|---|---|---|
| Ensemble (w/o Best) | 1 | 0.9676 |
| Ensemble (w/o Best) | 2 | 0.9689 |
| Ensemble (w/ Best) | 1 | 0.9696 |

This project was carried out as part of boostcamp AI Tech, which is managed by NAVER Connect Foundation. The dataset may be used only for boostcamp education and may not be shared externally, so this post includes very few images. More detailed code for this project is available on my GitHub.

This post is licensed under CC BY 4.0 by the author.
