Image Segmentation Based on the U-Net Architecture: Implementation and Code Example for Building Recognition

Image segmentation is the process of dividing an image into groups of pixels (segments) that correspond to the objects in the image. In other words, we classify each pixel in the image as belonging to a specific object or not (where "not" may simply mean the background).

At this point, image segmentation may sound like object detection with extra steps: instead of just finding the objects we care about, we have to find every one of their pixels.

But that is not the whole story. Image segmentation lets us simplify the image into a representation that other algorithms can easily digest (a binary mask instead of an RGB image) while preserving information about the sizes, shapes and spatial positions of objects, which is helpful in areas like medical imaging or self-driving cars (more about that below).
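To make this concrete, here is a tiny illustrative sketch (not part of the original article): an RGB image is an H x W x 3 array of colors, while the segmentation result is a much simpler H x W binary mask of the same spatial size.

# Illustrative sketch (not from the article): an RGB image vs. the binary mask
# a segmentation model produces for it - same spatial size, far simpler values.
import numpy as np

rgb_image = np.random.randint(0, 256, size=(4, 4, 3), dtype=np.uint8)  # H x W x 3 colors
binary_mask = np.array([[0, 0, 1, 1],
                        [0, 1, 1, 1],
                        [0, 0, 1, 0],
                        [0, 0, 0, 0]], dtype=np.uint8)                  # H x W, 1 = object pixel
print(rgb_image.shape, binary_mask.shape)   # (4, 4, 3) (4, 4)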

Image segmentation applications

  • Medical imaging – segmenting healthy and tumor tissue in histopathological images allows rapid assessment of a patient’s cancer stage, and segmentation can also automatically detect and mark bone fractures in X-ray images
  • Self-driving cars – autonomous vehicles can only be as good as their perception, which requires a system that can accurately locate pedestrians, obstacles, road signs and other vehicles in images
  • Aerial imagery for civil engineering and agriculture – image segmentation can help assess the progress of construction on building sites or monitor crop health by analyzing aerial images

Neural networks for image segmentation

Although some interesting non-learning algorithms have been used for image segmentation, such as Otsu’s method or the Watershed algorithm, most real-world segmentation problems today are solved by training neural networks (NNs) – more precisely, convolutional neural networks (CNNs), and even more specifically, encoder-decoder CNNs. One of the most prominent architectures of this kind is U-Net (the name comes from its shape, see Picture 1). U-Net consists of two parts:

  • Encoder (left part of the “U”) – encodes the image into an abstract feature representation by applying a sequence of convolutional blocks that gradually reduce the height and width of the representation while increasing the number of channels that correspond to image features.
  • Decoder (right part of the “U”) – decodes the feature representation into a mask by applying a sequence of up-convolutions (not the same as deconvolutions) that gradually increase the height and width of the representation back to the size of the original image and reduce the number of channels to the number of classes we segment.

In addition, U-Net implements skip connections that link the corresponding levels of the encoder and decoder. They allow the model not to “lose” the features extracted by earlier encoder blocks, which improves segmentation performance.
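To make the shape bookkeeping concrete, here is a minimal, stand-alone sketch (an illustration, not the article’s implementation) that walks one feature map through a single encoder/decoder level with a skip connection:

# One encoder/decoder level with a skip connection (illustration only).
import torch
from torch import nn

x = torch.randn(1, 64, 128, 128)                # feature map from an encoder block
skip = x                                        # kept for the skip connection
down = nn.MaxPool2d(2)(x)                       # (1, 64, 64, 64)   - encoder halves H and W
up = nn.ConvTranspose2d(64, 64, 2, 2)(down)     # (1, 64, 128, 128) - decoder doubles them back
merged = torch.cat([up, skip], dim=1)           # (1, 128, 128, 128) - skip connection restores detail
print(down.shape, up.shape, merged.shape)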

Implementation

For the demo part of this article, we will implement and train the U-Net architecture using PyTorch and segmentation_models_pytorch to segment buildings in aerial images.

Preparation

For our project, we will need to install the following packages:

  • PyTorch
  • segmentation_models_pytorch
  • Pillow (PIL)
  • OpenCV
  • Albumentations
  • Pandas
  • Matplotlib

For the data, we need to download and unpack the Massachusetts Buildings Dataset.

Once we have everything in place, we can start by importing everything we will need during the implementation:

import os
from typing import Tuple, List

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import albumentations as album
import segmentation_models_pytorch as smp

In addition, we must indicate on which device we will train U-Net:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

I highly recommend using ‘cuda’; otherwise, the training will take a very long time. If you do not have access to a GPU, consider using a free one from Google Colab or Kaggle.

Finally, we will define the path to our dataset and load a CSV with the mapping of our classes:

DATA_DIR = "/kaggle/input/massachusetts-buildings-dataset/tiff"
class_dict = pd.read_csv("/kaggle/input/massachusetts-buildings-dataset/label_class_dict.csv", index_col=0)

Data preparation

One of the simplest techniques we can apply to prevent the model from overfitting is training data augmentation. In addition, because CNNs are agnostic to input shapes, meaning that a CNN trained on one input size can be used on different sizes, we can train the model on smaller crops of the images to make better use of computing resources. Of course, we want to avoid these augmentations during validation in order to get a clear view of the model’s performance.

def get_training_augmentation():
    train_transform = [
        album.RandomCrop(height=256, width=256, always_apply=True),
        album.OneOf(
            [
                album.HorizontalFlip(p=1),
                album.VerticalFlip(p=1),
                album.RandomRotate90(p=1),
            ],
            p=0.75,
        ),
    ]
    return album.Compose(train_transform)

def get_validation_augmentation():
    test_transform = [
        album.PadIfNeeded(min_height=1536, min_width=1536, always_apply=True, border_mode=0)
    ]
    return album.Compose(test_transform)
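As a quick check that the pipeline behaves as expected, we can run it on dummy arrays (this snippet is an addition for illustration, not part of the original article):

# Quick check of the training augmentation on dummy arrays (illustration only).
dummy_img = np.random.rand(512, 512, 3).astype("float32")
dummy_mask = np.zeros((512, 512, 3), dtype="uint8")
augmented = get_training_augmentation()(image=dummy_img, mask=dummy_mask)
print(augmented["image"].shape, augmented["mask"].shape)   # both cropped to 256 x 256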

We can now implement a Dataset class that will load images and masks from disk, convert the RGB masks to binary masks and apply the specified augmentations.

def encode_mask(mask, df_labels):
    channels = []
    for c in df_labels.index:
        rgb = torch.tensor(df_labels.loc[c].to_list()).view(-1, 1, 1)
        _mask = torch.all(mask == rgb, dim=0).float()
        channels.append(_mask)
    return torch.stack(channels, dim=0)
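To see what encode_mask does, here is a tiny made-up example (the two-class label table below is an assumption for illustration, not the real label_class_dict.csv): a 2 x 2 RGB mask becomes a stack of one binary channel per class.

# Tiny illustration of encode_mask with made-up labels.
tiny_labels = pd.DataFrame({"r": [0, 255], "g": [0, 255], "b": [0, 255]},
                           index=["background", "building"])
tiny_mask = torch.tensor([[[0, 255], [255, 0]]] * 3).float()   # RGB mask of shape (3, 2, 2)
print(encode_mask(tiny_mask, tiny_labels).shape)               # torch.Size([2, 2, 2])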

class BuildingDataset(Dataset):

    def __init__(self, split: str, data_dir: str, df_labels: pd.DataFrame, augmentation=None):
        self.img_dir = os.path.join(data_dir, split)
        self.mask_dir = os.path.join(data_dir, split + "_labels")

        self.sample_names = os.listdir(self.img_dir)

        self.df_labels = df_labels
        self.augmentation = augmentation

    def __len__(self):
        return len(self.sample_names)

    def __getitem__(self, idx):
        sample_name = self.sample_names[idx]
        img_path = os.path.join(self.img_dir, sample_name)
        # label file names are the image names without the last character (.tiff -> .tif)
        mask_path = os.path.join(self.mask_dir, sample_name[:-1])

        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB) / 255
        mask = cv2.cvtColor(cv2.imread(mask_path), cv2.COLOR_BGR2RGB)
        if self.augmentation:
            sample = self.augmentation(image=img, mask=mask)
            img, mask = sample['image'], sample['mask']

        img = torch.tensor(img.transpose(2, 0, 1).astype('float32'))
        mask = torch.tensor(mask.transpose(2, 0, 1).astype('float32'))
        mask = encode_mask(mask, self.df_labels)

        return img, mask


train_dataset = BuildingDataset("train", DATA_DIR, class_dict, augmentation=get_training_augmentation())
val_dataset = BuildingDataset("val", DATA_DIR, class_dict, augmentation=get_validation_augmentation())
test_dataset = BuildingDataset("test", DATA_DIR, class_dict, augmentation=get_validation_augmentation())

Below you can see some sample training images and their masks.

Figure 2. Sample training images and their masks
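If you want to reproduce such a preview yourself, here is a minimal matplotlib sketch (an addition for illustration, not part of the original article; in particular, treating mask channel 1 as the building class is an assumption):

# Minimal preview sketch: show an image and one mask channel side by side.
def show_sample(dataset, idx=0):
    img, mask = dataset[idx]                        # tensors of shape (C, H, W)
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(img.permute(1, 2, 0))            # CHW -> HWC for matplotlib
    axes[0].set_title("image")
    axes[1].imshow(mask[1], cmap="gray")            # assuming channel 1 is the building class
    axes[1].set_title("mask")
    for ax in axes:
        ax.axis("off")
    plt.show()

show_sample(train_dataset)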

Model implementation

Now we can implement the model. We will start with a single convolutional block that will handle the bulk of the computation.

class ConvBlock(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, n_convs: int = 2,
                 kernel_size: int = 3, padding: int = 1) -> None:
        super(ConvBlock, self).__init__()
        # the first convolution maps in_channels -> out_channels, the remaining ones keep out_channels
        _in_channels = [in_channels] + [out_channels] * (n_convs - 1)
        self.model = nn.Sequential(*[self._get_single_block(ic, out_channels, kernel_size,
                                                            padding) for ic in _in_channels])

    def _get_single_block(self, in_channels: int, out_channels: int, kernel_size: int = 3,
                          padding: int = 1) -> nn.Sequential:
        return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size,
                                       padding=padding),
                             nn.BatchNorm2d(out_channels),
                             nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

Next, we can implement a DownBlock that will apply the convolutions and return two tensors: the main output, i.e. the convolution output after max pooling that reduces its size, and the skip-connection tensor, which has the same spatial size as the input.

class DownBlock(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, n_convs: int = 2,
                 kernel_size: int = 3, padding: int = 1) -> None:
        super(DownBlock, self).__init__()
        self.conv = ConvBlock(in_channels, out_channels, n_convs, kernel_size, padding)
        self.down_sample = nn.MaxPool2d(2)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        skipped_x = self.conv(x)
        x = self.down_sample(skipped_x)
        return x, skipped_x

Next in line is the UpBlock, which receives two tensors: the output of the previous block, which the UpBlock upsamples, and the skip-connection output of the matching DownBlock. The UpBlock concatenates these inputs and applies convolutions to them.

class UpBlock(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, n_convs: int = 2,
                 kernel_size: int = 3, padding: int = 1) -> None:
        super(UpBlock, self).__init__()
        # in_channels counts the concatenated channels, so the tensor being upsampled
        # has in_channels - out_channels channels
        self.up_sample = nn.ConvTranspose2d(in_channels - out_channels, in_channels - out_channels,
                                            kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels, out_channels, n_convs, kernel_size, padding)

    def forward(self, x: torch.Tensor, skipped_x: torch.Tensor) -> torch.Tensor:
        x = self.up_sample(x)
        x = torch.cat([x, skipped_x], dim=1)
        x = self.conv(x)
        return x

Finally, we can put it all together into the U-Net implementation:

class UNet(nn.Module):

    def __init__(self, in_channels: int = 3, out_classes: int = 2) -> None:
        super(UNet, self).__init__()
        self.down_0 = DownBlock(in_channels, 64)
        self.down_1 = DownBlock(64, 128)
        self.down_2 = DownBlock(128, 256)
        self.down_3 = DownBlock(256, 512)
        self.bottleneck = ConvBlock(512, 1024)
        self.up_0 = UpBlock(1024 + 512, 512)
        self.up_1 = UpBlock(512 + 256, 256)
        self.up_2 = UpBlock(256 + 128, 128)
        self.up_3 = UpBlock(128 + 64, 64)
        self.final_conv = nn.Conv2d(64, out_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, skipped_x_0 = self.down_0(x)
        x, skipped_x_1 = self.down_1(x)
        x, skipped_x_2 = self.down_2(x)
        x, skipped_x_3 = self.down_3(x)
        x = self.bottleneck(x)
        x = self.up_0(x, skipped_x_3)
        x = self.up_1(x, skipped_x_2)
        x = self.up_2(x, skipped_x_1)
        x = self.up_3(x, skipped_x_0)
        return self.final_conv(x)
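Before moving on, we can quickly sanity-check the implementation with a dummy input (this check is an addition for illustration; note that the input’s height and width should be divisible by 16 because of the four pooling steps):

# Quick shape check: the output has one channel per class and the input's spatial size.
unet = UNet(in_channels=3, out_classes=2)
dummy = torch.randn(1, 3, 256, 256)
print(unet(dummy).shape)   # expected: torch.Size([1, 2, 256, 256])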

Training and evaluation

Finally, we can train our model, and for convenience we will use segmentation_models_pytorch (smp) to do this. Note that we could just as well use raw PyTorch – smp simply saves us some hassle with implementing the metrics, loss functions and training loop ourselves.

First, we will initialize our model, define the training hyperparameters (feel free to tweak and play with them), and set up the loss function, metric, optimizer and data loaders.

model = UNet(out_classes=2)
n_epochs = 15
batch_size = 32
lr = 5e-5

loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(threshold=0.5)]
optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=lr)
])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(val_dataset, batch_size=1)
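Before launching the training, it can be worth sanity-checking the data pipeline. A minimal sketch (an addition for illustration, not part of the original article) that pulls one batch and prints its shapes:

# Sanity check of the data pipeline (illustration only).
imgs, masks = next(iter(train_loader))
print(imgs.shape, masks.shape)   # expected: (32, 3, 256, 256) and (32, 2, 256, 256)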

Second, using smp’s functionality, we initialize instances of the training and validation epoch runners.

train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=device,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=device,
    verbose=True,
)

And finally, we can train our model. Below you can see output from my sample run:

for i in range(0, n_epochs):
    print(f'\nEpoch: {i}')
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(validation_loader)

Figure 3

To test our model and make sure we did not overfit, we can run a single validation epoch with a newly initialized test data loader:

test_loader = DataLoader(test_dataset, batch_size=1)
test_logs = valid_epoch.run(test_loader)

Figure 4

Below you can see how the model performs on the test set:

Figure 5. Segmentation results on the test set
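For reference, here is a hedged sketch of how such predictions can be produced (an addition for illustration; applying a sigmoid, thresholding at 0.5 and treating channel 1 as the building class are all assumptions):

# Inference sketch: predict a building mask for one test image.
model.eval()
with torch.no_grad():
    img, mask = test_dataset[0]
    logits = model(img.unsqueeze(0).to(device))        # shape (1, 2, H, W)
    probs = torch.sigmoid(logits).squeeze(0).cpu()      # per-class probabilities
    building_mask = (probs[1] > 0.5).float()             # binary mask for the building class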

Summary

Image segmentation continues to be one of the most important research areas in computer vision. Today it is dominated by convolutional neural networks, which allow us to keep pushing the boundaries of what is possible in computer vision. In addition, the example above demonstrates that training a well-performing model is relatively simple if you have a suitable dataset, and it does not require that many resources.

Call to action

Need the support of Python experts to implement solutions in your app? Our skilled developers will be happy to share their knowledge and experience.

