import numpy as np
import matplotlib.pyplot as plt
def create_test_image(size=35, half=4, peak=5.0):
    """Build a synthetic square test image.

    The image is zeros with a (2*half+1) x (2*half+1) bright square of
    intensity 1.0 centered in the frame, plus a single high-intensity
    pixel at the exact center that serves as the keypoint.

    Args:
        size: edge length of the square image (default 35).
        half: half-width of the bright square (default 4 -> 9x9 square).
        peak: intensity of the central keypoint pixel (default 5.0).

    Returns:
        float32 ndarray of shape (size, size).
    """
    img = np.zeros((size, size), dtype=np.float32)
    c = size // 2
    # Bright square (intensity 1.0): picked up by the segmentation step.
    img[c - half:c + half + 1, c - half:c + half + 1] = 1.0
    # Single bright pixel at the center: picked up by keypoint detection.
    img[c, c] = peak
    return img
def get_segmentation(img, threshold=0.5):
    """Binarize *img*: 1 where intensity >= threshold, else 0 (uint8)."""
    mask = img >= threshold
    return mask.astype(np.uint8)
def get_keypoint(img):
    """Return the centroid [y, x] of all pixels with intensity >= 3.

    Falls back to [0.0, 0.0] when no pixel reaches the threshold.
    """
    ys, xs = np.nonzero(img >= 3)
    if ys.size == 0:
        return np.array([0.0, 0.0])
    return np.array([ys.mean(), xs.mean()])
# Region of interest: a 15x15 crop, deliberately offset from the image center.
ROI_START_Y = 10
ROI_START_X = 10
ROI_SIZE = 15
ROI_END_Y = ROI_START_Y + ROI_SIZE
ROI_END_X = ROI_START_X + ROI_SIZE
# Resolution the ROI is resampled to before analysis.
TARGET_SIZE = 32

original_img = create_test_image()
# Ground-truth keypoint: the center pixel of the 35x35 image.
true_keypoint = np.array([17, 17])
print("Original shape:", original_img.shape)

SpaceTransformer Motivation
Introduction: Space Transforms Are Harder Than They Look
Standard pipelines—crop ROI → resample → analyze → map results back—are error-prone with mainstream libraries (SimpleITK, scipy.ndimage, PyTorch). We compare them and show how SpaceTransformer provides accurate transforms with concise code.
Test Setup
Synthetic 35×35 image with a 9×9 square and a central keypoint. Crop a 15×15 ROI (slightly offset), resample to 32×32, run segmentation + keypoint detection, map results back to the original grid.
Helper for plotting results.
def plot_result(original_img, segment_result, keypoint_result, method_name, true_keypoint):
    """Overlay a method's segmentation contour and detected keypoint on the
    original image, together with the ground-truth keypoint and the
    keypoint localization error in pixels.
    """
    plt.figure(figsize=(8, 6))
    plt.imshow(original_img, cmap='gray', alpha=0.7)
    if segment_result is not None:
        # Draw the mask boundary at the 0.5 level.
        plt.contour(segment_result, levels=[0.5], colors='red', linewidths=2)
    has_keypoint = keypoint_result is not None and len(keypoint_result) > 0
    if has_keypoint:
        # Accept either a flat [y, x] vector or an (N, 2) array (first row used).
        kp = keypoint_result if keypoint_result.ndim == 1 else keypoint_result[0]
        plt.plot(kp[1], kp[0], 'ro', markersize=8, label='Detected')
    plt.plot(true_keypoint[1], true_keypoint[0], 'g+', markersize=12, markeredgewidth=3, label='Ground Truth')
    if has_keypoint:
        err = np.linalg.norm(kp - true_keypoint)
        plt.text(0.02, 0.98, f'Error: {err:.3f}px', transform=plt.gca().transAxes,
                 bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), va='top')
    plt.title(f'{method_name} Result')
    plt.legend()
    plt.grid(True, alpha=0.3)
plt.tight_layout(); plt.show()

Pipeline Definition
- Crop: extract 15×15 ROI from 35×35 image
- Resample: resize ROI to 32×32
- Analyze: segmentation + keypoint detection
- Map back: transform results to original space
We implement with different libraries.
Method 1: SimpleITK
import SimpleITK as sitk
def process_with_simpleitk(img):
    """Crop -> resample -> analyze -> map back, using SimpleITK.

    Note the axis bookkeeping: NumPy arrays index as (y, x) while SimpleITK
    sizes/indices are ordered (x, y), so coordinates are swapped by hand in
    several places below.
    """
    image = sitk.GetImageFromArray(img)
    image.SetSpacing([1.0, 1.0])
    image.SetOrigin([0.0, 0.0])

    # Crop the ROI; SimpleITK expects (x, y) ordering here.
    roi_size = [ROI_SIZE, ROI_SIZE]
    roi_start = [ROI_START_X, ROI_START_Y]
    roi_img = sitk.RegionOfInterest(image, roi_size, roi_start)

    # Resample the ROI to the target grid, deriving the output spacing from
    # the ROI's physical extent.
    target_size = [TARGET_SIZE, TARGET_SIZE]
    original_spacing = roi_img.GetSpacing()
    physical_size = [roi_size[i] * original_spacing[i] for i in range(2)]
    target_spacing = [physical_size[i] / target_size[i] for i in range(2)]

    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(target_spacing)
    resampler.SetSize(target_size)
    resampler.SetOutputOrigin(roi_img.GetOrigin())
    resampler.SetOutputDirection(roi_img.GetDirection())
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampled_img = resampler.Execute(roi_img)
    resampled_array = sitk.GetArrayFromImage(resampled_img)

    # Analysis on the resampled ROI.
    segment = get_segmentation(resampled_array)
    keypoint = get_keypoint(resampled_array)

    # Manual keypoint mapping back to the original grid: scale by the size
    # ratio, then offset by the ROI start (swapped back to (y, x)).
    scale_factor = np.array(roi_size) / np.array(target_size)
    keypoint_roi = keypoint * scale_factor
    keypoint_original = keypoint_roi + np.array([roi_start[1], roi_start[0]])

    # Resample the mask back to ROI resolution with nearest neighbor,
    # then paste it into a blank full-size canvas.
    segment_sitk = sitk.GetImageFromArray(segment.astype(np.float32))
    segment_sitk.SetSpacing(target_spacing)
    segment_sitk.SetOrigin(resampled_img.GetOrigin())

    back_resampler = sitk.ResampleImageFilter()
    back_resampler.SetOutputSpacing(original_spacing)
    back_resampler.SetSize(roi_size)
    back_resampler.SetOutputOrigin(roi_img.GetOrigin())
    back_resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    back_resampler.SetDefaultPixelValue(0)
    segment_roi = back_resampler.Execute(segment_sitk)
    segment_roi_array = sitk.GetArrayFromImage(segment_roi)

    segment_original = np.zeros(img.shape, dtype=np.uint8)
    segment_original[ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X] = segment_roi_array
    return segment_original, keypoint_original
# Run the SimpleITK pipeline on the synthetic image.
print("=== SimpleITK ===")
sitk_segment, sitk_key = process_with_simpleitk(original_img)
plot_result(original_img, sitk_segment, sitk_key, "SimpleITK", true_keypoint)

Issues: verbose resampling setup, manual coordinate math, easy axis mistakes.
Method 2: scipy.ndimage
from scipy.ndimage import zoom
def process_with_scipy(img):
    """Crop -> resample -> analyze -> map back, using scipy.ndimage.zoom.

    All coordinate mapping is done by hand with a single scalar zoom factor,
    which makes the inverse mapping sensitive to rounding.
    """
    roi = img[ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X]
    factor = TARGET_SIZE / ROI_SIZE

    # Upsample the ROI with linear interpolation (order=1).
    resampled = zoom(roi, factor, order=1, mode='constant', cval=0)
    segment = get_segmentation(resampled)
    keypoint = get_keypoint(resampled)

    # Inverse mapping: divide by the zoom factor, then shift by the ROI origin.
    keypoint_original = keypoint / factor + [ROI_START_Y, ROI_START_X]

    # Downsample the mask with nearest neighbor (order=0) and paste it into
    # a blank full-size canvas.
    segment_roi = zoom(segment.astype(np.float32), 1.0 / factor, order=0, mode='constant', cval=0)
    segment_original = np.zeros_like(img, dtype=np.uint8)
    segment_original[ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X] = segment_roi
    return segment_original, keypoint_original
# Run the scipy.ndimage pipeline on the synthetic image.
print("=== scipy.ndimage ===")
scipy_segment, scipy_key = process_with_scipy(original_img)
plot_result(original_img, scipy_segment, scipy_key, "scipy.ndimage", true_keypoint)

Issues: manual scaling, rounding errors when reversing zoom.
Method 3: PyTorch interpolate
import torch
import torch.nn.functional as F
def process_with_pytorch(img):
    """Crop -> resample -> analyze -> map back, using F.interpolate.

    Requires manual (batch, channel) dimension management; the
    align_corners convention determines where samples land.
    """
    # Lift (H, W) -> (1, 1, H, W) for the interpolate API.
    tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
    roi_tensor = tensor[:, :, ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X]

    resampled = F.interpolate(roi_tensor, size=(TARGET_SIZE, TARGET_SIZE), mode='bilinear', align_corners=False)
    resampled_arr = resampled.squeeze().numpy()

    segment = get_segmentation(resampled_arr)
    keypoint = get_keypoint(resampled_arr)

    # Manual inverse mapping: scale back to ROI resolution, then offset.
    scale = ROI_SIZE / TARGET_SIZE
    keypoint_original = keypoint * scale + [ROI_START_Y, ROI_START_X]

    # Nearest-neighbor downsample of the mask, then paste into a blank canvas.
    segment_tensor = torch.from_numpy(segment.astype(np.float32)).unsqueeze(0).unsqueeze(0)
    segment_roi = F.interpolate(segment_tensor, size=(ROI_SIZE, ROI_SIZE), mode='nearest').squeeze().numpy().astype(np.uint8)
    segment_original = np.zeros_like(img, dtype=np.uint8)
    segment_original[ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X] = segment_roi
    return segment_original, keypoint_original
# Run the PyTorch pipeline on the synthetic image.
print("=== PyTorch ===")
torch_segment, torch_key = process_with_pytorch(original_img)
plot_result(original_img, torch_segment, torch_key, "PyTorch interpolate", true_keypoint)

Issues: align_corners confusion, manual batch/channel management, visible mask offset.
Method 4: SpaceTransformer
from spacetransformer.core import Space, warp_point
from spacetransformer.torch import warp_image
def process_with_spacetransformer(img):
    """Crop -> resample -> analyze -> map back, using SpaceTransformer.

    The source and target spaces are declared once; warp_image / warp_point
    derive both forward and inverse coordinate math from them. The 2D image
    is lifted to 3D with a singleton leading axis because Space is 3D.
    """
    src_space = Space(shape=[1] + list(img.shape), spacing=(1.0, 1.0, 1.0), origin=(0.0, 0.0, 0.0))
    # Declare the target grid: crop the ROI bbox, then set the output shape.
    dst_space = (src_space
                 .apply_bbox([(0, 1), (ROI_START_Y, ROI_END_Y), (ROI_START_X, ROI_END_X)])
                 .apply_shape((1, TARGET_SIZE, TARGET_SIZE)))

    resampled = warp_image(img[None], src_space, dst_space,
                           mode='trilinear', pad_value=0, cuda_device='cpu', numpy=True)[0]
    segment = get_segmentation(resampled)
    kp2d = get_keypoint(resampled)
    # Lift the 2D keypoint into the 3D space (leading axis index 0).
    kp3d = np.array([[0, kp2d[0], kp2d[1]]])

    # Inverse transforms come for free by swapping the two spaces.
    segment_back = warp_image(segment[None], dst_space, src_space,
                              mode='nearest', pad_value=0, cuda_device='cpu', numpy=True)[0]
    kp_back = warp_point(kp3d, dst_space, src_space)[0]
    return segment_back, kp_back[0, 1:3]
# Run the SpaceTransformer pipeline on the synthetic image.
print("=== SpaceTransformer ===")
st_segment, st_key = process_with_spacetransformer(original_img)
plot_result(original_img, st_segment, st_key, "SpaceTransformer", true_keypoint)

Results
- Accuracy: SpaceTransformer yields zero offset. Other methods show varying coordinate drifts; PyTorch visibly shifts the mask.
- Developer effort: SpaceTransformer uses a declarative space description; code is concise. Others require manual bookkeeping.
- Design: It separates transformation planning from execution, hiding fragile math and ensuring consistent behavior.
Next chapter dives into SpaceTransformer’s internals.