Paper Review | Segment Anything, SAM

Concept, Task, Model, Data, Experimental Result

2023.04.26 by Suho Cho

Recently, LLMs have been showing very impressive zero-shot and few-shot performance.

Segment Anything Model (hereinafter SAM), recently announced by Meta, aims to be a foundation model for image segmentation, tackling many kinds of image segmentation problems through prompt engineering.

I wish I had developed it myself, but let's take a closer look at the paper so we can learn from it quickly and improve the data quality at DataHunt.

Paper link: Segment Anything
Github link: Segment Anything Github

Segment Anything Introduction

Segment Anything overview: the authors aim to build a foundation model for segmentation by introducing three interconnected components (task, model, and data).

Foundation model is a collective term for giant models pre-trained on huge datasets, regardless of domain. These models show tremendous generalizability to downstream tasks. In short, they are models with a massive amount of understanding.

The purpose of this paper is to develop a foundation model for image segmentation. It is zero-shot, meaning it can segment objects the model has never been trained on.

As with our recent series, there are some prerequisites: Task, Model, and Data. The authors of the paper ask three questions in this regard:

  1. What task will enable zero-shot generalization?
  2. What is the corresponding model architecture?
  3. What data can power this task and model?

Segment Anything Task

We call ChatGPT a prompt-driven model in the sense that when a user asks for something, it produces output accordingly. But can we do the same with segmentation?

There are many different types of prompts, and SAM is designed to accept points, boxes, and text as input. In fact, segmentation by pointing is not a new approach; work like Deep Extreme Cut (DEXTR), RITM, and FocalClick has been done under the name interactive segmentation.

Segment Anything Model

The model can be broken down into three parts: the Image encoder, the Prompt encoder, and the Mask decoder. For convenience, we will group the Prompt encoder and Mask decoder together and call them the Prompt model. The Image encoder is a ViT (Vision Transformer) based model for processing high-resolution images, which requires relatively more time for inference; in our tests, the vit_l model took 0.669s (RTX 3090) and 8.257s (CPU) for a single 1080 x 1920 image. The Prompt model, on the other hand, is fast enough for real-time inference, and the same image embedding can be reused across multiple promptable tasks (point, bounding box, mask, text).

A characteristic of these behemoth models is that they all have a powerful encoder, after which the back end is relatively lightweight. The model introduced in this paper is the Segment Anything Model (SAM), and it consists of three main parts.

  • Powerful Image Encoder
  • Prompt Encoder
  • Mask Decoder

Segment Anything Model (SAM) overview

Compared to the Image encoder, the Prompt encoder and Mask decoder are so lightweight that they are claimed to run in under 50ms in a web browser! Let's take a look at how they do it.
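For intuition about this split, here is a minimal usage sketch with the official segment_anything package (the checkpoint filename, image path, and click coordinates are placeholders): set_image runs the heavy ViT image encoder once, and each subsequent predict call only runs the lightweight prompt encoder and mask decoder on the cached embedding.

import numpy as np
import cv2
import torch
from segment_anything import sam_model_registry, SamPredictor

# Checkpoint and image paths below are placeholders.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l.pth").to(device)
predictor = SamPredictor(sam)

# Heavy step: run the ViT image encoder once per image.
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)
print(predictor.get_image_embedding().shape)  # cached embedding, (1, 256, 64, 64)

# Light steps: the prompt encoder + mask decoder reuse the cached embedding.
masks, scores, _ = predictor.predict(
    point_coords=np.array([[500, 375]]),  # a single foreground click (x, y)
    point_labels=np.array([1]),
    multimask_output=True,                # return three candidate masks
)

box_masks, _, _ = predictor.predict(
    box=np.array([100, 100, 400, 400]),   # a box prompt on the same image embedding
    multimask_output=False,
)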

Segment Anything Data

Before we talk about the dataset, we need to mention the Data Engine. Text for training GPT is all over the internet, but the masks needed for segmentation are not easy to find. So the authors created a Data Engine, which consists of three stages.

  1. Assisted-manual: The annotator clicks a point and SAM produces a rough mask, which the annotator then refines. This is assisted labeling.
  2. Semi-automatic: SAM creates masks on its own for a specific set of objects, while humans annotate the remaining objects at the same time. For example, you can tell SAM "generate masks for all apples in the image," and while it is making the apple masks, a human works on the other objects.
  3. Fully automatic: SAM takes a grid of points over the image and generates masks for everything on its own (a usage sketch follows below).

In particular, this is what the fully automatic stage looks like.

Segment Anything model automatic

Fully automatic needs a little tweaking, but the results are amazing!

This is how they were able to collect so much data: 1B masks from 11M images.
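The fully automatic stage corresponds to what the released code exposes as SamAutomaticMaskGenerator: it prompts SAM with a regular grid of points and keeps the masks that pass quality and stability filtering. A minimal sketch, again with placeholder checkpoint and image paths (the threshold values shown are the library defaults):

import cv2
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l.pth").to(device)

# points_per_side controls the density of the point grid; the other thresholds
# filter out low-confidence and unstable masks.
mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,
    pred_iou_thresh=0.88,
    stability_score_thresh=0.95,
)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)

# Each entry is a dict with keys such as 'segmentation' (binary mask),
# 'area', 'bbox', 'predicted_iou', and 'stability_score'.
print(len(masks), masks[0]["segmentation"].shape)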

Segment Anything Method

We've explained that this model consists of three components. Let's go through them one by one.

Image Encoder

The authors use a ViT pre-trained with a Masked Autoencoder (MAE). I don't think I need to say much more about this model because it has already proven itself. When an image comes in, it goes through the Image encoder to produce an embedding; whatever is done with that embedding afterward, the Image encoder's job ends there.

Prompt Encoder

There are two kinds of prompts; let's take a look at each and how they are created.

Sparse prompt

  • Points: Positional encodings summed with learned embeddings
  • Boxes: Positional encodings summed with learned embeddings
  • Text: CLIP text encoder output


def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    """Embeds box prompts."""
    boxes = boxes + 0.5  # Shift to center of pixel
    coords = boxes.reshape(-1, 2, 2)
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight
    return corner_embedding

sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
if points is not None:
    coords, labels = points
    point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
    sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
    box_embeddings = self._embed_boxes(boxes)
    sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
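For intuition, the pe_layer used above maps coordinates to random Fourier features. Below is a simplified sketch of that idea; the actual PositionEmbeddingRandom class in the repo differs in details (it also handles grids and batched inputs), so treat this as an illustration rather than the exact implementation.

import torch

def random_fourier_point_encoding(coords_xy: torch.Tensor,
                                  image_size: tuple,
                                  gaussian: torch.Tensor) -> torch.Tensor:
    """Map (x, y) pixel coordinates to a positional encoding.

    coords_xy: (N, 2) pixel coordinates; gaussian: (2, C//2) fixed random matrix.
    """
    h, w = image_size
    coords = coords_xy.clone().float()
    coords[:, 0] = coords[:, 0] / w      # normalize x to [0, 1]
    coords[:, 1] = coords[:, 1] / h      # normalize y to [0, 1]
    coords = 2 * coords - 1              # shift to [-1, 1]
    proj = 2 * torch.pi * coords @ gaussian
    return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)  # (N, C)

# 256-dim encoding from a fixed random Gaussian projection
gaussian = torch.randn(2, 128)
pe = random_fourier_point_encoding(torch.tensor([[500.0, 375.0]]), (1024, 1024), gaussian)
print(pe.shape)  # torch.Size([1, 256])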

Dense prompt

  • Mask: Downscaled with convolutions and then added element-wise to the image embedding


self.mask_downscaling = nn.Sequential(
    nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
    LayerNorm2d(mask_in_chans // 4),
    activation(),
    nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
    LayerNorm2d(mask_in_chans),
    activation(),
    nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)

The code is quite long, so I've excerpted only part of it; the source is here.

Mask decoder

This is the part that takes the image embedding and the prompt embeddings and predicts the mask. Basically, it uses prompt self-attention and cross-attention in the Transformer decoder block in a bidirectional way (by bidirectional I mean prompt-to-image and image-to-prompt). The reason it is bidirectional is that both the image embedding and the prompt embeddings need to be updated.

Details of the lightweight mask decoder

Keep in mind that this model works interactively with the prompt.

It then performs upsampling to fit the image size and determines for each pixel whether it should be included in the mask.

Importantly, it doesn't generate labels! It just creates a mask.
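To make "bidirectional" concrete, here is a minimal sketch of one two-way attention step in plain PyTorch. The real TwoWayTransformer in the repo adds output tokens, MLPs, layer norms, positional encodings, and an upscaling head, so this only illustrates the prompt-to-image and image-to-prompt attention flow.

import torch
import torch.nn as nn

class TwoWayAttentionSketch(nn.Module):
    """Illustrative block: prompt tokens and image embedding update each other."""

    def __init__(self, dim: int = 256, heads: int = 8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.token_to_image = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.image_to_token = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, tokens, image):  # tokens: (B, T, C), image: (B, HW, C)
        # 1) self-attention over the prompt tokens
        tokens = tokens + self.self_attn(tokens, tokens, tokens)[0]
        # 2) prompt-to-image: tokens attend to the image embedding
        tokens = tokens + self.token_to_image(tokens, image, image)[0]
        # 3) image-to-prompt: the image embedding attends back to the tokens
        image = image + self.image_to_token(image, tokens, tokens)[0]
        return tokens, image

block = TwoWayAttentionSketch()
tokens, image = block(torch.randn(1, 6, 256), torch.randn(1, 64 * 64, 256))
print(tokens.shape, image.shape)  # torch.Size([1, 6, 256]) torch.Size([1, 4096, 256])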

Resolving Ambiguity

But if we give just one point, is there only one mask it could mean? When we click a pixel on a person's fingernail, it is hard to know whether we want exactly that fingernail, the hand, or a mask of the whole person. Normally, when you get multiple mask candidates, you would average the per-pixel confidences into a single mask, but that is noisy and does not give you a correct mask of the person.

So the paper suggests producing three mask candidates (three turns out to be enough to cover a whole object, a part, and a subpart) and backpropagating only on the one with the minimum loss.
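In training-loop terms, the selection looks roughly like the sketch below, where mask_loss is any per-sample mask loss (such as the focal + dice combination discussed in the next section) and the model is assumed to return three candidate masks per prompt; both are placeholders, not the authors' code.

import torch

def ambiguity_aware_loss(pred_masks: torch.Tensor, gt_mask: torch.Tensor, mask_loss) -> torch.Tensor:
    """Backpropagate only through the best of the candidate masks.

    pred_masks: (B, 3, H, W) three candidate masks per prompt
    gt_mask:    (B, H, W)    the single ground-truth mask
    mask_loss:  callable returning a per-sample loss of shape (B,)
    """
    losses = torch.stack(
        [mask_loss(pred_masks[:, i], gt_mask) for i in range(pred_masks.shape[1])],
        dim=1,
    )                                  # (B, 3) loss per candidate
    min_loss, _ = losses.min(dim=1)    # keep only the best candidate per sample
    return min_loss.mean()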

Mask ambiguity. Different masks may be intended from the same point.

Losses and Training

Above, we mentioned learning only on the mask with the minimum loss. In this paper, the loss is a linear combination of focal loss and dice loss (see Appendix A.2).

  • Focal loss weights learning toward the more difficult examples.
  • Dice loss puts a bit more emphasis on overlap than the familiar IoU; the Dice coefficient counts the true-positive overlap with the GT twice, so it is always at least as large as IoU.

Dice / IoU comparison illustration. Source: https://ilmonteux.github.io/2019/05/10/segmentation-metrics.html
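Here is a hedged sketch of the two loss terms for binary masks. The 20:1 focal-to-dice weighting follows the ratio reported in the paper's appendix, while the alpha/gamma values and other details are common defaults rather than the authors' exact settings. A function like this could be plugged into the candidate-selection sketch above.

import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha: float = 0.25, gamma: float = 2.0):
    """Focal loss: down-weights easy pixels so training focuses on hard ones."""
    prob = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * ce).mean(dim=(-2, -1))   # (B,)

def dice_loss(logits, targets, eps: float = 1.0):
    """Dice loss: 1 - Dice coefficient, which rewards overlap with the GT."""
    prob = torch.sigmoid(logits)
    inter = (prob * targets).sum(dim=(-2, -1))
    union = prob.sum(dim=(-2, -1)) + targets.sum(dim=(-2, -1))
    return 1 - (2 * inter + eps) / (union + eps)                     # (B,)

def mask_loss(logits, targets):
    # 20:1 focal-to-dice weighting, as reported in the paper's appendix
    return 20.0 * focal_loss(logits, targets) + dice_loss(logits, targets)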

SA-1B Dataset

Another important point in this paper is that in addition to task and model, we have data.

  • Images: 11M high-resolution images taken by photographers, with an average resolution of a whopping 3300x4950.
  • Masks: 1.1B masks were obtained from the above images, 99.1% of which were generated by the fully automatic method.
  • Mask Quality: To evaluate the quality of the masks, the authors randomly sampled 500 images (about 50k masks) and had professionals correct them, yielding a ground-truth mask for each auto-generated mask. Measuring the IoU between the two gave roughly 85-91% IoU (around 0.9 on average), so the automatically generated masks look pretty good.

SA-1B dataset statistics. Even geographic statistics are included.

Experimental Results

The results are awesome. Of course, if you want to use this model for data labeling, that's another story, but the fact that it's zero-shot makes it incredibly versatile.

Zero-Shot Single Point Valid Mask Evaluation

Segment Anything Result Example

The authors also didn't forget to compare it to other existing interactive segmentation methods, and it shows very good overall performance compared to RITM, which was announced by Samsung. In figure (a), the GTEA dataset at the bottom is the Georgia Tech Egocentric Activity dataset, which captures daily activities from a first-person perspective. SAM performs particularly poorly on it, so I took a closer look, but didn't find an obvious reason.

Zero-Shot Single Point Valid Mask Evaluation

Zero-Shot Instance Segmentation

Comparing the performance of ViTDet and SAM on the COCO and LVIS datasets, we found that SAM falls short of ViTDet quantitatively, but does better on Human Rating.

Zero-Shot Instance Segmentation

Zero-Shot Text-to-Mask

The experiment of performing segmentation from free-form text is a kind of proof of concept to test the model's capability. When text input alone does not produce an accurate output, it can be combined with click information, as in PhraseClick.

Zero-Shot Text-to-Mask

There are a few other details, but here's the short version:

  • The model was trained for 3-5 days on 256 A100 GPUs
  • The image encoder takes ~0.15 seconds on an NVIDIA A100 GPU
  • The prompt encoder and mask decoder take ~50ms on CPU in the browser using multithreaded SIMD execution
  • The image encoder has 632M parameters
  • The prompt encoder and mask decoder have 4M parameters

Conclusion

The model calls itself the "foundation model for image segmentation," and I think it deserves the title. Plus, once the image embedding is computed, prompts can be used to create masks very quickly and freely, which I found genuinely useful. I hope to try it out soon and share my impressions.
