Extending Binary Image Segmentation to Multi-Class Image Segmentation

Darshita Jain
May 7, 2021

During my MTech, I was trying to solve the problem of detecting over-exposed, under-exposed, and properly exposed regions in an input image. I addressed this problem using multi-class image segmentation. Many articles and videos explain the concept and implementation of binary image segmentation, but I could not find a similar explanation of how to perform multi-class image segmentation. Through this blog, I would like to help by sharing my experience.

This article discusses the changes you need to make to binary image segmentation code to perform multi-class image segmentation.

Before starting with this blog, I strongly recommend going through the blog below. I used it to understand how to apply transfer learning to binary image segmentation. Note that I used that author's code as a template and made the required changes to it.

The four main changes we need to make are to the:

  • Input segmentation mask format
  • Loss function
  • Evaluation metric
  • Number of output channels of the model

Let us look at each of the above points in detail.

1. Input segmentation mask format: For image segmentation, every input image needs a corresponding segmentation mask. A segmentation mask denotes which class each pixel of the image belongs to. Segmentation can therefore be considered a dense classification task, because every pixel must be classified into one of a set of predefined classes. There are different types of segmentation masks, such as one-hot encoded masks, RGB segmentation masks, and class-indexed masks. In this blog, I focus on the class-indexed mask, as it is compatible with many existing segmentation models as well as with the loss function discussed next. The image below depicts a class-indexed mask. Each class is given a class id; for example, Person is assigned id 1, Purse is assigned id 2, and so on. Each pixel of the input image is then classified into one of the given classes and assigned the corresponding class id. The resulting segmentation mask is a single-channel (grayscale) image whose pixel values are class ids.
Example of a class-indexed segmentation mask. Image source: https://www.jeremyjordan.me/semantic-segmentation/
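If your annotations arrive as one-hot or colour (RGB) masks, they can be converted to class-indexed masks before training. The following is a minimal NumPy sketch; the colour palette PALETTE and the helper names one_hot_to_index and rgb_to_index are hypothetical and should be adapted to the colours your own masks actually use.

```python
import numpy as np

# Hypothetical colour palette (RGB colour -> class id); adapt it to your masks.
PALETTE = {
    (0, 0, 0): 0,
    (255, 255, 255): 1,
    (0, 0, 255): 2,
}


def one_hot_to_index(one_hot):
    """Convert an (H, W, C) one-hot mask into an (H, W) class-indexed mask."""
    return np.argmax(one_hot, axis=-1).astype(np.uint8)


def rgb_to_index(rgb, palette=PALETTE):
    """Convert an (H, W, 3) colour mask into an (H, W) class-indexed mask."""
    index = np.zeros(rgb.shape[:2], dtype=np.uint8)
    matched = np.zeros(rgb.shape[:2], dtype=bool)
    for colour, class_id in palette.items():
        # Pixels whose three channels all equal this palette colour.
        hit = np.all(rgb == np.array(colour, dtype=rgb.dtype), axis=-1)
        index[hit] = class_id
        matched |= hit
    if not matched.all():
        raise ValueError("mask contains colours that are not in the palette")
    return index


rgb = np.array([[[0, 0, 0], [255, 255, 255]],
                [[0, 0, 255], [0, 0, 0]]], dtype=np.uint8)
print(rgb_to_index(rgb))
# [[0 1]
#  [2 0]]
```

When saving class-indexed masks to disk, use a lossless format such as PNG: lossy JPEG compression alters pixel values and would therefore corrupt the class ids.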

To annotate your input segmentation data, you can use the LabelMe annotation tool. LabelMe is open-source, easy to use, and was originally developed at MIT.

An example of an annotated segmentation mask created using the LabelMe tool.
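LabelMe saves each annotation as a JSON file containing a "shapes" list (each shape has a "label", its "points" as [x, y] pairs, and a "shape_type") together with "imageHeight" and "imageWidth". The sketch below rasterises the polygon shapes into a class-indexed mask using only NumPy. The function names and the label-to-id mapping are hypothetical, non-polygon shapes are simply skipped, and LabelMe's own conversion scripts may be preferable in practice.

```python
import numpy as np


def polygon_to_bool_mask(points, height, width):
    """Rasterise one polygon with the even-odd rule, testing pixel centres."""
    ys, xs = np.mgrid[0:height, 0:width]
    px, py = xs + 0.5, ys + 0.5
    inside = np.zeros((height, width), dtype=bool)
    n = len(points)
    for i in range(n):
        x1, y1 = points[i]
        x2, y2 = points[(i + 1) % n]
        # Does this edge cross the horizontal line through each pixel centre?
        crosses = (y1 > py) != (y2 > py)
        # x-coordinate of that crossing; horizontal edges divide by zero,
        # but those pixels are already excluded by `crosses`.
        with np.errstate(divide="ignore", invalid="ignore"):
            x_cross = x1 + (py - y1) * (x2 - x1) / (y2 - y1)
        # Toggle pixels lying to the left of the crossing (ray casting).
        inside ^= crosses & (px < x_cross)
    return inside


def labelme_to_index_mask(annotation, class_ids):
    """Build an (H, W) class-indexed mask from a parsed LabelMe JSON dict.

    Unannotated pixels keep id 0; later shapes overwrite earlier ones.
    """
    h, w = annotation["imageHeight"], annotation["imageWidth"]
    mask = np.zeros((h, w), dtype=np.uint8)
    for shape in annotation["shapes"]:
        if shape.get("shape_type", "polygon") != "polygon":
            continue  # this sketch handles polygons only
        region = polygon_to_bool_mask(shape["points"], h, w)
        mask[region] = class_ids[shape["label"]]
    return mask


# Usage (hypothetical file name and labels):
# import json
# with open("image_001.json") as f:
#     annotation = json.load(f)
# mask = labelme_to_index_mask(annotation, {"over_exposed": 1, "under_exposed": 2})
```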

2. Loss function: Categorical cross-entropy loss is the standard choice for multi-class semantic segmentation. Since semantic segmentation assigns a class id to every pixel of the image, the loss is computed for each pixel and averaged over the image. Note that categorical cross-entropy assumes single-label classification, i.e., each pixel belongs to exactly one category.

The softmax activation function is generally applied before calculating the cross-entropy loss. Softmax turns the model's raw scores into a probability distribution over the classes for each pixel. Cross-entropy then compares this predicted distribution with the ground-truth class of the pixel, penalizing the model when it assigns a low probability to the correct class.

To include this change, modify the variable criterion in the file main.py (available in the GitHub repository mentioned above).
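To make the per-pixel loss concrete: with C output channels, the model produces a score (logit) for every class at every pixel, softmax turns these scores into probabilities, and the loss is the average over all pixels of -log(probability assigned to the ground-truth class). The following NumPy sketch computes this for a single image with a class-indexed target; the function names are illustrative and are not taken from the referenced repository.

```python
import numpy as np


def log_softmax(logits, axis=0):
    """Numerically stable log-softmax over the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def pixelwise_cross_entropy(logits, target):
    """Mean categorical cross-entropy over all pixels of one image.

    logits: (C, H, W) raw model outputs, one channel per class.
    target: (H, W) class-indexed mask with integer values in [0, C).
    """
    log_probs = log_softmax(logits, axis=0)
    # At every pixel, pick the log-probability of the ground-truth class.
    idx = target.astype(np.intp)[None, ...]
    picked = np.take_along_axis(log_probs, idx, axis=0)[0]
    return float(-picked.mean())


target = np.array([[0, 1], [2, 1]])
uniform = np.zeros((3, 2, 2))  # no preference between the 3 classes
print(round(pixelwise_cross_entropy(uniform, target), 4))  # 1.0986, i.e. ln(3)
confident = 10.0 * np.eye(3)[target].transpose(2, 0, 1)  # high score on the true class
print(pixelwise_cross_entropy(confident, target))  # close to 0
```

If the code base is PyTorch-based, torch.nn.CrossEntropyLoss computes the same quantity: it expects raw logits of shape (N, C, H, W) and a target of shape (N, H, W) with dtype torch.long, applies log-softmax internally, and by default averages over all pixels. The model output should therefore be passed to it without a preceding softmax, and setting criterion = torch.nn.CrossEntropyLoss() is one plausible way to make the change described above.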

3. Evaluation metrics: Two widely used metrics for semantic segmentation are the Dice score (F1 score) and the Jaccard score (Intersection over Union, or IoU). The Dice score is twice the overlap between the predicted segmentation mask and the ground-truth mask, divided by the total number of pixels in both masks. It can be visualized as shown in the figure.

Visualization of Dice score

Another commonly used metric is the Jaccard score, or IoU score. It is calculated by dividing the overlap between the ground-truth mask and the predicted mask by their union, i.e., the number of pixels that belong to either mask (overlapping pixels counted once). It can be visualized as shown in the figure.

Visualization of Jaccard score

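Note that these two metrics are more informative than plain pixel accuracy on imbalanced data, i.e., when the pixels of one class far outnumber those of the other classes in an image. This holds when they are computed for each class separately and then averaged, so that a model cannot score well simply by predicting the dominant class everywhere.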
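In the multi-class case, Dice and IoU are computed separately for each class (treating that class as foreground and everything else as background) and then averaged with equal weight per class, often called the macro or mean score. For predicted pixels P and ground-truth pixels G of a class, Dice = 2|P ∩ G| / (|P| + |G|) and IoU = |P ∩ G| / |P ∪ G|. The NumPy sketch below, with hypothetical helper names, does this for class-indexed masks and contrasts it with pixel accuracy on a deliberately imbalanced toy mask.

```python
import numpy as np


def per_class_scores(pred, target, num_classes):
    """Per-class Dice and IoU for class-indexed masks of identical shape."""
    dice = np.full(num_classes, np.nan)
    iou = np.full(num_classes, np.nan)
    for c in range(num_classes):
        p, t = pred == c, target == c
        inter = np.logical_and(p, t).sum()
        union = np.logical_or(p, t).sum()
        total = p.sum() + t.sum()
        if total > 0:  # class absent from both masks -> left as NaN
            dice[c] = 2 * inter / total
            iou[c] = inter / union
    return dice, iou


def mean_scores(pred, target, num_classes):
    """Macro average: each class counts equally, whatever its pixel count."""
    dice, iou = per_class_scores(pred, target, num_classes)
    return float(np.nanmean(dice)), float(np.nanmean(iou))


# Imbalanced toy example: 95 pixels of class 0, 5 pixels of class 1.
target = np.zeros((10, 10), dtype=np.uint8)
target[0, :5] = 1
pred = np.zeros_like(target)  # a "model" that always predicts class 0

pixel_acc = float((pred == target).mean())
mean_dice, mean_iou = mean_scores(pred, target, num_classes=2)
print(pixel_acc, round(mean_iou, 3), round(mean_dice, 3))  # 0.95 0.475 0.487
```

Pixel accuracy rewards the trivial prediction with 0.95, while the macro IoU and Dice expose that class 1 was missed entirely. Classes that appear in neither mask are left as NaN and ignored by the average; libraries differ in how they treat such classes, so it is worth checking this convention before comparing numbers. If scikit-learn is available, f1_score and jaccard_score with average="macro" on the flattened masks compute the corresponding macro-averaged scores.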

To include this change, modify the variable metrics in the file main.py (available in the GitHub repository mentioned above).

Don't forget to include the required import statement.

4. Number of output channels of the model: This depends on the number of classes to predict. In my case, there were three classes (over-exposed, under-exposed, and properly exposed regions), so the number of output channels was set to 3. This can be changed in the file model.py (available in the GitHub repository mentioned above).
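The number of output channels matters because the model emits one score map per class: for C classes, the final layer produces C logits at every pixel, and the predicted class-indexed mask is the argmax over that channel axis. The NumPy sketch below models the final layer as a 1x1 convolution with randomly initialised, hypothetical weights and random stand-in features; in a PyTorch model this usually corresponds to the out_channels of the last convolution in the segmentation head.

```python
import numpy as np

NUM_CLASSES = 3  # over-exposed, under-exposed, properly exposed


def segmentation_head(features, weight, bias):
    """1x1 convolution: (F, H, W) features -> (C, H, W) logits.

    weight has shape (C, F) and bias shape (C,); C is the number of
    output channels and must equal the number of classes.
    """
    return np.einsum("cf,fhw->chw", weight, features) + bias[:, None, None]


def logits_to_mask(logits):
    """Highest-scoring channel at each pixel -> (H, W) class-indexed mask."""
    return np.argmax(logits, axis=0)


rng = np.random.default_rng(0)
features = rng.standard_normal((16, 8, 8))  # stand-in for backbone features
weight = rng.standard_normal((NUM_CLASSES, 16))
bias = np.zeros(NUM_CLASSES)

logits = segmentation_head(features, weight, bias)
mask = logits_to_mask(logits)
print(logits.shape, mask.shape)  # (3, 8, 8) (8, 8)
```

These C-channel logits are what the cross-entropy loss takes as input during training, and the argmax mask is what Dice and IoU are computed on at evaluation time.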

Some useful resources

If you are interested, have a look at my research paper, “Deep Over and Under Exposed Region Detection”, which was accepted at the 5th IAPR International Conference on Computer Vision and Image Processing (CVIP) 2020.

That’s it for this blog. I hope you find it useful.


Darshita Jain

Researcher at Tata Research Development and Design Centre | MTech CSE | IIT Gandhinagar