SpaceTransformer Motivation

Introduction: Space Transforms Are Harder Than They Look

Standard pipelines—crop ROI → resample → analyze → map results back—are error-prone with mainstream libraries (SimpleITK, scipy.ndimage, PyTorch). We compare them and show how SpaceTransformer provides accurate transforms with concise code.

Test Setup

Synthetic 35×35 image with a 9×9 square and a central keypoint. Crop a 15×15 ROI (slightly offset), resample to 32×32, run segmentation + keypoint detection, map results back to the original grid.

import numpy as np
import matplotlib.pyplot as plt

def create_test_image():
    img = np.zeros((35,35), dtype=np.float32)
    c = 35 // 2
    half = 4
    img[c-half:c+half+1, c-half:c+half+1] = 1.0
    img[c, c] = 5
    return img

def get_segmentation(img, threshold=0.5):
    return (img >= threshold).astype(np.uint8)

def get_keypoint(img):
    candidates = np.array(np.where(img >= 3))
    if candidates.size > 0:
        y = candidates[0].mean()
        x = candidates[1].mean()
        return np.array([y, x])
    return np.array([0.0, 0.0])

ROI_START_Y = 10; ROI_START_X = 10; ROI_SIZE = 15
ROI_END_Y = ROI_START_Y + ROI_SIZE
ROI_END_X = ROI_START_X + ROI_SIZE
TARGET_SIZE = 32

original_img = create_test_image()
true_keypoint = np.array([17, 17])
print("Original shape:", original_img.shape)

Helper for plotting results.

def plot_result(original_img, segment_result, keypoint_result, method_name, true_keypoint):
    plt.figure(figsize=(8,6))
    plt.imshow(original_img, cmap='gray', alpha=0.7)
    if segment_result is not None:
        plt.contour(segment_result, levels=[0.5], colors='red', linewidths=2)
    if keypoint_result is not None and len(keypoint_result)>0:
        if keypoint_result.ndim == 1:
            plt.plot(keypoint_result[1], keypoint_result[0], 'ro', markersize=8, label='Detected')
        else:
            plt.plot(keypoint_result[0,1], keypoint_result[0,0], 'ro', markersize=8, label='Detected')
    plt.plot(true_keypoint[1], true_keypoint[0], 'g+', markersize=12, markeredgewidth=3, label='Ground Truth')
    if keypoint_result is not None and len(keypoint_result)>0:
        err = np.linalg.norm((keypoint_result if keypoint_result.ndim==1 else keypoint_result[0]) - true_keypoint)
        plt.text(0.02, 0.98, f'Error: {err:.3f}px', transform=plt.gca().transAxes,
                 bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), va='top')
    plt.title(f'{method_name} Result')
    plt.legend(); plt.grid(True, alpha=0.3)
    plt.tight_layout(); plt.show()

Pipeline Definition

  1. Crop: extract 15×15 ROI from 35×35 image
  2. Resample: resize ROI to 32×32
  3. Analyze: segmentation + keypoint detection
  4. Map back: transform results to original space

We implement with different libraries.

Method 1: SimpleITK

import SimpleITK as sitk

def process_with_simpleitk(img):
    sitk_img = sitk.GetImageFromArray(img)
    sitk_img.SetSpacing([1.0, 1.0]); sitk_img.SetOrigin([0.0, 0.0])

    roi_size = [ROI_SIZE, ROI_SIZE]
    roi_start = [ROI_START_X, ROI_START_Y]
    roi_img = sitk.RegionOfInterest(sitk_img, roi_size, roi_start)

    target_size = [TARGET_SIZE, TARGET_SIZE]
    original_spacing = roi_img.GetSpacing()
    physical_size = [roi_size[i] * original_spacing[i] for i in range(2)]
    target_spacing = [physical_size[i] / target_size[i] for i in range(2)]

    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(target_spacing)
    resampler.SetSize(target_size)
    resampler.SetOutputOrigin(roi_img.GetOrigin())
    resampler.SetOutputDirection(roi_img.GetDirection())
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampled_img = resampler.Execute(roi_img)
    resampled_array = sitk.GetArrayFromImage(resampled_img)

    segment = get_segmentation(resampled_array)
    keypoint = get_keypoint(resampled_array)

    scale_factor = np.array(roi_size) / np.array(target_size)
    keypoint_roi = keypoint * scale_factor
    keypoint_original = keypoint_roi + np.array([roi_start[1], roi_start[0]])

    segment_sitk = sitk.GetImageFromArray(segment.astype(np.float32))
    segment_sitk.SetSpacing(target_spacing)
    segment_sitk.SetOrigin(resampled_img.GetOrigin())

    back_resampler = sitk.ResampleImageFilter()
    back_resampler.SetOutputSpacing(original_spacing)
    back_resampler.SetSize(roi_size)
    back_resampler.SetOutputOrigin(roi_img.GetOrigin())
    back_resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    back_resampler.SetDefaultPixelValue(0)
    segment_roi = back_resampler.Execute(segment_sitk)
    segment_roi_array = sitk.GetArrayFromImage(segment_roi)

    segment_original = np.zeros(img.shape, dtype=np.uint8)
    segment_original[ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X] = segment_roi_array
    return segment_original, keypoint_original

print("=== SimpleITK ===")
sitk_segment, sitk_key = process_with_simpleitk(original_img)
plot_result(original_img, sitk_segment, sitk_key, "SimpleITK", true_keypoint)

Issues: verbose resampling setup, manual coordinate math, easy axis mistakes.

Method 2: scipy.ndimage

from scipy.ndimage import zoom

def process_with_scipy(img):
    roi = img[ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X]
    factor = TARGET_SIZE / ROI_SIZE
    resampled = zoom(roi, factor, order=1, mode='constant', cval=0)
    segment = get_segmentation(resampled)
    keypoint = get_keypoint(resampled)

    keypoint_roi = keypoint / factor
    keypoint_original = keypoint_roi + [ROI_START_Y, ROI_START_X]

    segment_roi = zoom(segment.astype(np.float32), 1.0/factor, order=0, mode='constant', cval=0)
    segment_original = np.zeros_like(img, dtype=np.uint8)
    segment_original[ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X] = segment_roi
    return segment_original, keypoint_original

print("=== scipy.ndimage ===")
scipy_segment, scipy_key = process_with_scipy(original_img)
plot_result(original_img, scipy_segment, scipy_key, "scipy.ndimage", true_keypoint)

Issues: manual scaling, rounding errors when reversing zoom.

Method 3: PyTorch interpolate

import torch
import torch.nn.functional as F

def process_with_pytorch(img):
    tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
    roi_tensor = tensor[:, :, ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X]
    resampled = F.interpolate(roi_tensor, size=(TARGET_SIZE, TARGET_SIZE), mode='bilinear', align_corners=False)
    resampled_arr = resampled.squeeze().numpy()

    segment = get_segmentation(resampled_arr)
    keypoint = get_keypoint(resampled_arr)

    scale = ROI_SIZE / TARGET_SIZE
    keypoint_original = keypoint * scale + [ROI_START_Y, ROI_START_X]

    segment_tensor = torch.from_numpy(segment.astype(np.float32)).unsqueeze(0).unsqueeze(0)
    segment_roi = F.interpolate(segment_tensor, size=(ROI_SIZE, ROI_SIZE), mode='nearest').squeeze().numpy().astype(np.uint8)

    segment_original = np.zeros_like(img, dtype=np.uint8)
    segment_original[ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X] = segment_roi
    return segment_original, keypoint_original

print("=== PyTorch ===")
torch_segment, torch_key = process_with_pytorch(original_img)
plot_result(original_img, torch_segment, torch_key, "PyTorch interpolate", true_keypoint)

Issues: align_corners confusion, manual batch/channel management, visible mask offset.

Method 4: SpaceTransformer

from spacetransformer.core import Space, warp_point
from spacetransformer.torch import warp_image

def process_with_spacetransformer(img):
    original_space = Space(shape=[1]+list(img.shape), spacing=(1.0,1.0,1.0), origin=(0.0,0.0,0.0))
    target_space = (original_space
        .apply_bbox([(0,1),(ROI_START_Y,ROI_END_Y),(ROI_START_X,ROI_END_X)])
        .apply_shape((1, TARGET_SIZE, TARGET_SIZE)))

    resampled = warp_image(img[None], original_space, target_space,
                            mode='trilinear', pad_value=0, cuda_device='cpu', numpy=True)[0]
    segment = get_segmentation(resampled)
    keypoint_2d = get_keypoint(resampled)

    keypoint_3d = np.array([[0, keypoint_2d[0], keypoint_2d[1]]])
    segment_back = warp_image(segment[None], target_space, original_space,
                              mode='nearest', pad_value=0, cuda_device='cpu', numpy=True)[0]
    keypoint_back = warp_point(keypoint_3d, target_space, original_space)[0]
    return segment_back, keypoint_back[0,1:3]

print("=== SpaceTransformer ===")
st_segment, st_key = process_with_spacetransformer(original_img)
plot_result(original_img, st_segment, st_key, "SpaceTransformer", true_keypoint)

Results

  • Accuracy: SpaceTransformer yields zero offset. Other methods show varying coordinate drifts; PyTorch visibly shifts the mask.
  • Developer effort: SpaceTransformer uses a declarative space description; code is concise. Others require manual bookkeeping.
  • Design: It separates transformation planning from execution, hiding fragile math and ensuring consistent behavior.

Next chapter dives into SpaceTransformer’s internals.