Caleb Robey

Segment Anything by Meta - All the Masks for Free

Updated: Sep 29, 2023

The new first step for virtually all computer vision problems.


Your data is basically perfect, right?


About 2 months ago, I was looking at a fairly standard computer vision problem for a client: build an algorithm that measures the dimensions of an object placed in front of a checkerboard whose square size is known. If I could just count the checkerboard squares adjacent to the object and adjust for various distortions, the problem would be solved.
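The measurement idea boils down to a scale factor: the known square size divided by the square's size in pixels. Below is a minimal sketch, assuming the inner checkerboard corners have already been detected and the board is roughly parallel to the camera (no perspective or lens-distortion correction). The function name mm_per_pixel and the synthetic corner grid are illustrative, not from the original project.

```python
import numpy as np


def mm_per_pixel(corners: np.ndarray, pattern_size: tuple[int, int], square_mm: float) -> float:
    """Estimate millimetres per pixel from detected inner checkerboard corners.

    corners: (cols * rows, 2) pixel coordinates, ordered row by row.
    pattern_size: (cols, rows) of inner corners.
    square_mm: known side length of one checkerboard square.
    Assumes a fronto-parallel board; perspective and lens distortion are ignored.
    """
    cols, rows = pattern_size
    grid = corners.reshape(rows, cols, 2)
    # Pixel distances between horizontally and vertically adjacent corners.
    dx = np.linalg.norm(np.diff(grid, axis=1), axis=-1).ravel()
    dy = np.linalg.norm(np.diff(grid, axis=0), axis=-1).ravel()
    # The median spacing is robust to a few badly localized corners.
    square_px = np.median(np.concatenate([dx, dy]))
    return square_mm / square_px


# Usage: a synthetic 7x5 grid of corners spaced 40 px apart, with 25 mm squares.
xs, ys = np.meshgrid(np.arange(7) * 40.0 + 100, np.arange(5) * 40.0 + 50)
corners = np.stack([xs, ys], axis=-1).reshape(-1, 2)
scale = mm_per_pixel(corners, (7, 5), 25.0)
object_length_mm = 320 * scale  # an object that spans 320 px in the image
```

A real setup would also correct perspective (for example with a homography) before trusting a single scale factor.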


Better still, finding a checkerboard is a well-documented problem in the computer vision community, and it is easily solved as long as:

  1. The image includes complete rows and columns of the checkerboard.

  2. The image's squares aren't too small.

  3. The X-corners (the points where four squares meet) can be located precisely, with minimal noise and distortion.

  4. The image was taken with a high-resolution camera.

  5. The checkerboard is in ideal lighting with high contrast.

  6. + 10 other common requirements...

Of course, all image data that we could receive as data scientists/engineers meet these requirements. Right?


You get the idea. Data quality is usually somewhere in the range from non-ideal to bad.


... enter SAM from Meta.


SAM as a Computer Vision Tool: Accurate, Tunable, Robust

SAM is an accurate and tunable model that Meta trained on over 1 billion masks across 11 million images. It takes an image as input (plus optional hints, or "prompts", such as points or boxes) and outputs masks for all of the objects of interest in the image. By objects of interest, I really just mean anything that a human annotator would consider an object in the image.

SAM as applied to a cityscape: from https://segment-anything.com/

SAM has proven stable across varying conditions, and it even segments objects it never saw during training fairly well.


If results like those can be generated in a stable manner, why shouldn't it be the starting point for every computer vision analysis problem? Instead of just RGB pixels, you get to start with a set of masks (as a human would define them) within an image. I'll take it.
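Concretely, "starting with a set of masks" means working with a list of per-object records rather than raw pixels. Below is a minimal sketch that collapses such a list into a single integer label image, which many classical tools (for example scikit-image's regionprops) accept directly. The function name is hypothetical. The 'segmentation' and 'area' keys are the ones SamAutomaticMaskGenerator.generate() returns.

```python
import numpy as np


def masks_to_label_image(masks: list[dict], shape: tuple[int, int]) -> np.ndarray:
    """Paint SAM mask records into one integer label image (0 = no mask).

    Larger masks are painted first, so smaller masks that overlap them stay
    visible on top. Labels follow area rank, not the original list order.
    """
    labels = np.zeros(shape, dtype=np.int32)
    order = sorted(range(len(masks)), key=lambda i: masks[i]["area"], reverse=True)
    for label, i in enumerate(order, start=1):
        labels[masks[i]["segmentation"]] = label
    return labels


# Usage with two synthetic records: a small square inside a large one.
big = np.ones((4, 4), dtype=bool)
small = np.zeros((4, 4), dtype=bool)
small[1:3, 1:3] = True
records = [{"segmentation": small, "area": 4}, {"segmentation": big, "area": 16}]
labels = masks_to_label_image(records, (4, 4))
```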


SAM in Action: The Hunt for Earbud Covers

Suppose I want to find the earbud covers in the image below. Could I do that under multiple lighting conditions? You can follow along (or make the use case more complex) in this Google Colab notebook.


I took these three images to compare.


One in fairly neutral lighting.

One flooded with white light.

And one flooded with orange light.


Then I ran all three through SAM (their paths are stored as pathlib.Path objects in the 'normal', 'orange_exposed', and 'white_exposed' variables):

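import cv2
import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Load the SAM model; sam_model holds the path to the downloaded ViT-H checkpoint
sam = sam_model_registry["vit_h"](checkpoint=sam_model)
if torch.cuda.is_available():
  sam.to('cuda')

# Build the generator once; points_per_batch=16 uses less GPU memory than the default
mask_generator = SamAutomaticMaskGenerator(sam, points_per_batch=16)

mask_dictionary = {}

for x in [normal, orange_exposed, white_exposed]:
  image = cv2.imread(str(x))  # OpenCV loads images as BGR, or returns None on failure
  if image is None:
    raise FileNotFoundError(x)
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # SAM expects an RGB uint8 (H, W, 3) array
  masks = mask_generator.generate(image)  # list of dicts, one per mask
  mask_dictionary[x.stem] = {'image': image, 'masks': masks}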
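Before any analysis, it can help to discard tiny or low-confidence masks. The minimal sketch below relies on the 'area', 'predicted_iou', and 'stability_score' keys that SamAutomaticMaskGenerator.generate() returns. The function name and the threshold defaults are illustrative guesses, not values from the original notebook.

```python
def filter_masks(masks, min_area=500, min_predicted_iou=0.9, min_stability=0.95):
    """Keep SAM mask records that pass simple size and quality thresholds."""
    return [
        m for m in masks
        if m["area"] >= min_area
        and m["predicted_iou"] >= min_predicted_iou
        and m["stability_score"] >= min_stability
    ]


# Usage with synthetic records (only the keys the filter reads).
records = [
    {"area": 1200, "predicted_iou": 0.97, "stability_score": 0.98},
    {"area": 80, "predicted_iou": 0.99, "stability_score": 0.99},    # too small
    {"area": 5000, "predicted_iou": 0.85, "stability_score": 0.97},  # low predicted IoU
]
kept = filter_masks(records)
```

SamAutomaticMaskGenerator also accepts pred_iou_thresh, stability_score_thresh, and min_mask_region_area at construction time. Filtering afterwards simply lets you try different thresholds without re-running the model.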

And the fully-segmented results are:

Any location on the image where there is a colored "tint" is where SAM has found an object or segmentation of interest. Clearly, the white and orange exposure affected the overall mask view to some degree, but the lighting changes did not appear to change the core objects that were segmented out.


To answer that question for sure, we need to do further analysis. Let's check on those earbud covers.
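One quantitative way to check whether the lighting changed the core objects is to match each mask from the normal image to its best-overlapping mask in an exposed image, using intersection over union (IoU). This is a minimal sketch with hypothetical helper names. It assumes the three photos share identical framing (the camera did not move between shots), so the same pixel refers to the same place in the scene.

```python
import numpy as np


def mask_iou(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection over union of two boolean masks with the same shape."""
    union = np.logical_or(a, b).sum()
    return float(np.logical_and(a, b).sum() / union) if union else 0.0


def best_match_ious(reference: list[dict], other: list[dict]) -> list[float]:
    """For each mask in reference, the highest IoU against any mask in other."""
    return [
        max((mask_iou(r["segmentation"], o["segmentation"]) for o in other), default=0.0)
        for r in reference
    ]


# Tiny synthetic check: a 2x2 square versus a 2x3 rectangle that contains it.
a = np.zeros((4, 4), dtype=bool)
a[:2, :2] = True
b = np.zeros((4, 4), dtype=bool)
b[:2, :3] = True
scores = best_match_ious([{"segmentation": a}], [{"segmentation": b}])
```

On the notebook's data, this would be called as best_match_ious(mask_dictionary[normal.stem]['masks'], mask_dictionary[white_exposed.stem]['masks']). Masks whose best IoU stays close to 1 are the ones the lighting change did not disturb, though what counts as "close" is a judgment call.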


Integration with Classical Computer Vision Techniques

With masks in hand, classical computer vision tools (transformations, shape analysis, and feature engineering) suddenly become very effective. So let's try feature engineering to find the earbud covers.


A useful first step is to find all of the circular masks in those images. Using a standard shape descriptor, we define the "circularity score" of each mask as 4π × area / perimeter², which is 1 for a perfect circle and smaller for elongated or irregular shapes:

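import numpy as np
from skimage import measure

def circularity_score(mask_metadata):
    """Circularity 4*pi*area / perimeter**2 of a SAM mask record (1.0 = perfect circle)."""
    mask = mask_metadata['segmentation']
    label_image = measure.label(mask)
    props = measure.regionprops(label_image)
    
    if len(props) == 0:
        return 0
    
    # Score the largest connected component, not whichever label comes first
    region = max(props, key=lambda p: p.area)
    area = region.area
    perimeter = region.perimeter
    
    if perimeter == 0:
        return 0
    
    circularity = (4 * np.pi * area) / (perimeter ** 2)
    # Scores above 1 come from the coarse pixel-based perimeter estimate
    # (e.g. on very small masks), so push them to the bottom of the ranking
    if circularity > 1:
      circularity = -1
    return circularity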
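To get a feel for the scale of the score, here is the formula the function implements, evaluated on ideal (continuous) shapes. Pixelated masks only approximate these values, which is also how a pixel-based estimate can end up above 1:

```latex
\begin{aligned}
C &= \frac{4\pi A}{P^{2}} \\
\text{circle, radius } r:&\quad A = \pi r^{2},\ P = 2\pi r
  \ \Rightarrow\ C = \frac{4\pi \cdot \pi r^{2}}{4\pi^{2} r^{2}} = 1 \\
\text{square, side } s:&\quad A = s^{2},\ P = 4s
  \ \Rightarrow\ C = \frac{4\pi s^{2}}{16 s^{2}} = \frac{\pi}{4} \approx 0.785 \\
\text{rectangle, } s \times 4s:&\quad A = 4s^{2},\ P = 10s
  \ \Rightarrow\ C = \frac{16\pi s^{2}}{100 s^{2}} = \frac{4\pi}{25} \approx 0.503
\end{aligned}
```

The more elongated or ragged the outline, the larger the perimeter for a given area, and the lower the score.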

Next, let's sort each image's masks by circularity and look at the nine most circular:

import matplotlib.pyplot as plt

for dict_idx, (key, subdict) in enumerate(mask_dictionary.items()):
  # Full-segmentation view; plot_masks holds one colored overlay per image,
  # built earlier in the notebook (not shown here)
  plt.title(key)
  plt.imshow(subdict['image'])
  plt.imshow(plot_masks[dict_idx])
  plt.axis('off')
  plt.show()
  # Sort the masks by circularity, most circular first
  sorted_masks = sorted(subdict['masks'], key=circularity_score, reverse=True)
  fig, ax = plt.subplots(3, 3, figsize=(10, 10))
  for idx, mask in enumerate(sorted_masks[:9]):
    ax[idx // 3, idx % 3].set_title(f"Circularity: {circularity_score(mask):.3f}")
    ax[idx // 3, idx % 3].imshow(mask['segmentation'])
  for panel in ax.flat:
    panel.axis('off')  # also hides empty panels if there are fewer than 9 masks
  subdict['circ_masks'] = sorted_masks[:9]
  plt.tight_layout()
  plt.show()

And thus we get the following masks for the normal image:

for the white-exposed image:

and for the orange-exposed image:

Great: the earbuds still show up among the top five or six circularity results under every lighting condition, and they're the only duplicated circular objects. To finish the job, we just need to count the contours in each mask and find the mask with 2 contours.

import cv2
import matplotlib.pyplot as plt
import numpy as np

def count_contours(mask):
    # findContours needs a uint8 image; RETR_TREE returns both outer boundaries
    # and hole boundaries (OpenCV 4.x returns contours and hierarchy)
    binary_mask = mask['segmentation'].astype(np.uint8)
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    return len(contours)

for key, subdict in mask_dictionary.items():
  for circular_mask in subdict['circ_masks']:
    if count_contours(circular_mask) > 1:
      plt.title(f"I found the earbuds! ({key})")
      plt.imshow(circular_mask['segmentation'])
      plt.show()

Drum roll...

Stop Wasting Time and Money on Annotation!

This is a simple use case. However, the principle is that you may no longer need to train a ResNet on 50,000 labeled images to get good, environment-agnostic segmentation results. SAM is often good enough to get you 90% of the way to solving your problem, and feature engineering can likely finish the job.


Even if classical computer vision isn't enough to finish the task of finding the mask(s) of interest, why not train a neural net on the masks instead of the RGB images? That could plausibly reduce your training data needs tenfold and drastically simplify labeling.
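One hypothetical way to start training on masks rather than pixels is to turn each mask record into a small feature vector that any off-the-shelf classifier can consume. This is a sketch only. The feature choice and the function name are illustrative. It relies on the 'segmentation', 'area', 'bbox' (XYWH), 'predicted_iou', and 'stability_score' keys that SamAutomaticMaskGenerator.generate() returns.

```python
import numpy as np


def mask_features(record: dict) -> np.ndarray:
    """A small, illustrative feature vector for one SAM mask record."""
    seg = record["segmentation"]
    x, y, w, h = record["bbox"]  # SAM reports boxes as (x, y, width, height)
    box_area = w * h
    return np.array([
        record["area"] / seg.size,                     # fraction of the image covered
        w / h if h else 0.0,                           # bounding-box aspect ratio
        record["area"] / box_area if box_area else 0.0,  # how much of its box the mask fills
        record["predicted_iou"],
        record["stability_score"],
    ], dtype=np.float64)


# Usage: a 2-row by 4-column block (area 8, bbox x=2, y=2, w=4, h=2).
seg = np.zeros((10, 10), dtype=bool)
seg[2:4, 2:6] = True
record = {"segmentation": seg, "area": 8, "bbox": [2, 2, 4, 2],
          "predicted_iou": 0.9, "stability_score": 0.95}
features = mask_features(record)
```

Stacking one such row per mask gives a compact training table. Labeling then means marking each mask as "target" or "not target" instead of tracing outlines by hand. Whether a handful of hand-picked features like these is enough depends on the problem.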


We have helped multinational corporations and small businesses solve these types of problems. Please reach out to info@depotanalytics.co if you would like to talk!
