CustomInference.md

This document contains examples of applying our models to custom images.

If you're fine-tuning models on downstream tasks, check the Extracting Representations Example, which shows how to use the pre-trained backbones independently of this codebase.

If you want to compute model outputs, check the High-Resolution Inference or Sentinel-2 Inference examples.

Extracting Representations Example

In this example we will load a pre-trained backbone without using this codebase. If you want to use our model architecture code, the other examples below may be more helpful.

Here's code to restore the Swin-v2-Base backbone of a single-image model for application to downstream tasks:

import torch
import torchvision
model = torchvision.models.swin_transformer.swin_v2_b()
full_state_dict = torch.load('satlas-model-v1-highres.pth')
# Extract just the Swin backbone parameters from the full state dict.
swin_prefix = 'backbone.backbone.'
swin_state_dict = {k[len(swin_prefix):]: v for k, v in full_state_dict.items() if k.startswith(swin_prefix)}
model.load_state_dict(swin_state_dict)

See Normalization.md for documentation on how images should be normalized for input to Satlas models.
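
As a minimal sketch of preparing an input (this assumes the single-image high-resolution model loaded above, whose inputs are 8-bit RGB images scaled to 0-1 as described in the High-Resolution Inference Example below; Sentinel-2 models use a different normalization, see Normalization.md):

import torchvision
# Read an 8-bit RGB image and scale its values to the 0-1 range expected by the
# single-image high-resolution model. The resulting tensor has shape (C, H, W).
im = torchvision.io.read_image('image.jpg').float() / 255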

Feature representations can be extracted like this:

# Assume im is shape (C, H, W).
x = im[None, :, :, :]
outputs = []
for layer in model.features:
    x = layer(x)
    outputs.append(x.permute(0, 3, 1, 2))
map1, map2, map3, map4 = outputs[-7], outputs[-5], outputs[-3], outputs[-1]
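
As a quick sanity check, here is a sketch based on the standard torchvision swin_v2_b configuration (patch size 4, embedding dimension 128, which is an assumption about the checkpoint rather than something stated above): the four maps should form a feature pyramid at strides 4, 8, 16, and 32 with 128, 256, 512, and 1024 channels respectively.

# For a (3, 1024, 1024) input we would expect shapes (1, 128, 256, 256),
# (1, 256, 128, 128), (1, 512, 64, 64), and (1, 1024, 32, 32).
for name, m in zip(['map1', 'map2', 'map3', 'map4'], [map1, map2, map3, map4]):
    print(name, m.shape)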

Here's code to compute feature representations from a multi-image model through max temporal pooling. Note the different prefix of the Swin backbone parameters; see the model architecture documentation for details.

import torch
import torchvision
model = torchvision.models.swin_transformer.swin_v2_b()
# Make sure to load a multi-image model here.
# Only the multi-image models are trained to provide robust features after max temporal pooling.
full_state_dict = torch.load('satlas-model-v1-lowres-multi.pth')
# Extract just the Swin backbone parameters from the full state dict.
swin_prefix = 'backbone.backbone.backbone.'
swin_state_dict = {k[len(swin_prefix):]: v for k, v in full_state_dict.items() if k.startswith(swin_prefix)}
model.load_state_dict(swin_state_dict)

# Assume im is shape (N, C, H, W), with N aligned images of the same location at different times.
# First get feature maps of each individual image.
x = im
outputs = []
for layer in model.features:
    x = layer(x)
    outputs.append(x.permute(0, 3, 1, 2))
feature_maps = [outputs[-7], outputs[-5], outputs[-3], outputs[-1]]
# Now apply max temporal pooling.
feature_maps = [
    m.amax(dim=0)
    for m in feature_maps
]
# feature_maps can be passed to a head, and the head or entire model can be trained to fine-tune on task-specific labels.
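
To illustrate that last comment, here is a minimal, hypothetical head that is not part of this codebase: it consumes only the coarsest pooled feature map and assumes the standard Swin-v2-Base channel counts (128, 256, 512, 1024), so treat it as a sketch rather than the Satlas head architecture.

import torch.nn as nn

num_classes = 2  # hypothetical number of classes for a downstream segmentation task
head = nn.Sequential(
    nn.Conv2d(1024, 256, kernel_size=3, padding=1),  # 1024 channels in the coarsest Swin-v2-Base map
    nn.ReLU(inplace=True),
    nn.Conv2d(256, num_classes, kernel_size=1),
)

# feature_maps[-1] is (C, H, W) after max temporal pooling, so add a batch dimension.
logits = head(feature_maps[-1].unsqueeze(0))
# logits can be upsampled to the input resolution and trained with a task-specific loss,
# with the backbone kept frozen or fine-tuned end-to-end.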

High-Resolution Inference Example

In this example we will apply a single-image high-resolution model on a high-resolution image.

If you don't have an image already, see an example of obtaining one. We will assume the image is saved as image.jpg.

We will assume you're using satlas-model-v1-highres.pth (pre-trained on SatlasPretrain). The expected input is an 8-bit RGB image, and input values should be divided by 255 so they fall between 0 and 1.
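
A quick sanity check of the input (a sketch that assumes image.jpg from the previous step is a standard 8-bit RGB file):

import torchvision
# Expect dtype torch.uint8 and shape (3, H, W) before the divide-by-255 normalization.
im = torchvision.io.read_image('image.jpg')
print(im.dtype, im.shape)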

First, obtain the code and the model:

git clone https://github.com/allenai/satlas
cd satlas
mkdir models
wget -O models/satlas-model-v1-highres.pth https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-highres.pth

Load the model, apply it to the image, and save visualizations of its outputs:

import json
import torch
import torchvision

import satlas.model.dataset
import satlas.model.evaluate
import satlas.model.model

# Locations of model config and weights, and the 8-bit RGB image to run inference on.
config_path = 'configs/highres_pretrain_old.txt'
weights_path = 'models/satlas-model-v1-highres.pth'
image_path = 'image.jpg'

# Read config and initialize the model.
with open(config_path, 'r') as f:
    config = json.load(f)
device = torch.device("cuda")
for spec in config['Tasks']:
    if 'Task' not in spec:
        spec['Task'] = satlas.model.dataset.tasks[spec['Name']]
model = satlas.model.model.Model({
    'config': config['Model'],
    'channels': config['Channels'],
    'tasks': config['Tasks'],
})
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Read image, apply model, and save output visualizations.
with torch.no_grad():
    im = torchvision.io.read_image(image_path)
    gpu_im = im.to(device).float() / 255
    outputs, _ = model([gpu_im])

    for task_idx, spec in enumerate(config['Tasks']):
        satlas.model.evaluate.visualize_outputs(
            task=spec['Task'],
            image=im.numpy().transpose(1, 2, 0),
            outputs=outputs[task_idx][0],
            vis_dir='./',
            save_prefix='out',
        )

See configs/highres_pretrain_old.txt for a list of all the heads of this model.
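
If you prefer to list the heads programmatically, a short sketch (it assumes, as in the code above, that the config is JSON and each task spec has a 'Name' field):

import json
with open('configs/highres_pretrain_old.txt', 'r') as f:
    config = json.load(f)
# Print the index and name of each head (task) defined in the config.
for task_idx, spec in enumerate(config['Tasks']):
    print(task_idx, spec['Name'])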

Sentinel-2 Inference Example

In this example we will apply a multi-image multi-band Sentinel-2 model on Sentinel-2 imagery.

If you don't have Sentinel-2 images merged and normalized for Satlas already, see the example of obtaining them, which also documents the normalization of Sentinel-2 bands expected by our models.

We will assume you're using the solar farm model (models/solar_farm/best.pth), but you could use another model like satlas-model-v1-lowres-multi.pth (the SatlasPretrain model) instead.

First, obtain the code and the model:

git clone https://github.com/allenai/satlas
cd satlas
wget https://pub-956f3eb0f5974f37b9228e0a62f449bf.r2.dev/satlas_explorer_datasets/satlas_explorer_datasets_2023-07-24.tar
tar xvf satlas_explorer_datasets_2023-07-24.tar

Now we can load the images, normalize them, and apply the model:

import json
import numpy as np
import skimage.io
import torch
import tqdm

import satlas.model.dataset
import satlas.model.evaluate
import satlas.model.model

# Locations of model config and weights, and the input image.
config_path = 'configs/satlas_explorer_solar_farm.txt'
weights_path = 'satlas_explorer_datasets/models/solar_farm/best.pth'
image_path = 'stack.npy'

# Read config and initialize the model.
with open(config_path, 'r') as f:
    config = json.load(f)
device = torch.device("cuda")
for spec in config['Tasks']:
    if 'Task' not in spec:
        spec['Task'] = satlas.model.dataset.tasks[spec['Name']]
model = satlas.model.model.Model({
    'config': config['Model'],
    'channels': config['Channels'],
    'tasks': config['Tasks'],
})
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

image = np.load(image_path)
# For (N, C, H, W) image (with N timestamps), convert to (N*C, H, W).
image = image.reshape(image.shape[0]*image.shape[1], image.shape[2], image.shape[3])

# The image may be large, so apply the model on fixed-size windows (crops).
# Here we collect the solar farm segmentation outputs (head 0).
solar_farm_vis = np.zeros((image.shape[1], image.shape[2], 3), dtype=np.uint8)
crop_size = 2048
head_idx = 0

with torch.no_grad():
    for row in tqdm.tqdm(range(0, image.shape[1], crop_size)):
        for col in range(0, image.shape[2], crop_size):
            crop = torch.as_tensor(image[:, row:row+crop_size, col:col+crop_size])
            gpu_crop = crop.to(device).float() / 255
            outputs, _ = model([gpu_crop])
            # Convert binary segmentation probabilities to classes.
            pred_cls = outputs[head_idx][0, :, :, :].cpu().numpy() > 0.5
            crop_colored = satlas.model.evaluate.segmentation_mask_to_color(config['Tasks'][head_idx]['Task'], pred_cls)
            solar_farm_vis[row:row+crop_size, col:col+crop_size, :] = crop_colored

skimage.io.imsave('rgb.png', image[0:3, :, :].transpose(1, 2, 0))
skimage.io.imsave('solar_farm.png', solar_farm_vis)
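
If you also want a rough coverage number, one option is to count non-black pixels in the visualization. This is a sketch that assumes segmentation_mask_to_color renders background pixels as (0, 0, 0); verify that against your config before relying on it.

# Fraction of pixels with any non-black color in the solar farm visualization.
coverage = (solar_farm_vis.sum(axis=2) > 0).mean()
print('fraction of pixels predicted as solar farm: %.4f' % coverage)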