Test Setup
Synthetic 35×35 image with a 9×9 square and a central keypoint. Crop a 15×15 ROI around the keypoint, resample it to 32×32, run segmentation and keypoint detection, then map the results back to the original grid.
import numpy as np
import matplotlib.pyplot as plt

def create_test_image():
    # 35×35 background with a 9×9 bright square; a high-intensity peak marks the keypoint
    img = np.zeros((35, 35), dtype=np.float32)
    c = 35 // 2
    half = 4
    img[c - half:c + half + 1, c - half:c + half + 1] = 1.0
    img[c, c] = 5
    return img

def get_segmentation(img, threshold=0.5):
    return (img >= threshold).astype(np.uint8)

def get_keypoint(img):
    # Keypoint = centroid of all pixels above the peak threshold, returned in (y, x) order
    candidates = np.array(np.where(img >= 3))
    if candidates.size > 0:
        y = candidates[0].mean()
        x = candidates[1].mean()
        return np.array([y, x])
    return np.array([0.0, 0.0])

ROI_START_Y = 10; ROI_START_X = 10; ROI_SIZE = 15
ROI_END_Y = ROI_START_Y + ROI_SIZE
ROI_END_X = ROI_START_X + ROI_SIZE
TARGET_SIZE = 32

original_img = create_test_image()
true_keypoint = np.array([17, 17])
print("Original shape:", original_img.shape)
Helper for plotting results.
def plot_result(original_img, segment_result, keypoint_result, method_name, true_keypoint):
    plt.figure(figsize=(8, 6))
    plt.imshow(original_img, cmap='gray', alpha=0.7)
    if segment_result is not None:
        plt.contour(segment_result, levels=[0.5], colors='red', linewidths=2)
    if keypoint_result is not None and len(keypoint_result) > 0:
        if keypoint_result.ndim == 1:
            plt.plot(keypoint_result[1], keypoint_result[0], 'ro', markersize=8, label='Detected')
        else:
            plt.plot(keypoint_result[0, 1], keypoint_result[0, 0], 'ro', markersize=8, label='Detected')
    plt.plot(true_keypoint[1], true_keypoint[0], 'g+', markersize=12, markeredgewidth=3, label='Ground Truth')
    if keypoint_result is not None and len(keypoint_result) > 0:
        err = np.linalg.norm((keypoint_result if keypoint_result.ndim == 1 else keypoint_result[0]) - true_keypoint)
        plt.text(0.02, 0.98, f'Error: {err:.3f} px', transform=plt.gca().transAxes,
                 bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), va='top')
    plt.title(f'{method_name} Result')
    plt.legend(); plt.grid(True, alpha=0.3)
    plt.tight_layout(); plt.show()
Pipeline Definition
1. Crop: extract the 15×15 ROI from the 35×35 image
2. Resample: resize the ROI to 32×32
3. Analyze: run segmentation + keypoint detection
4. Map back: transform the results to the original image space
We implement this pipeline with three libraries: SimpleITK, scipy.ndimage, and PyTorch. The scale-plus-offset math they all share is sketched below.
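All three implementations encode the same relation between the 32×32 crop and the original 35×35 grid. The helper below is a minimal sketch of that shared math, not part of any library or of the methods that follow; the name map_back_to_original is ours.

def map_back_to_original(keypoint_resampled, roi_start_yx, roi_size, target_size):
    # Illustration only: original = roi_start + keypoint * (roi_size / target_size).
    # The libraries below may use slightly different pixel-center conventions,
    # which is exactly where the sub-pixel drift comes from.
    scale = roi_size / target_size                      # e.g. 15 / 32
    return np.asarray(keypoint_resampled, dtype=float) * scale + np.asarray(roi_start_yx)

# Example: a keypoint at (15.5, 15.5) in the 32×32 crop maps to roughly (17.27, 17.27)
# in the 35×35 image:
# map_back_to_original([15.5, 15.5], [ROI_START_Y, ROI_START_X], ROI_SIZE, TARGET_SIZE)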
Method 1: SimpleITK
import SimpleITK as sitk

def process_with_simpleitk(img):
    # NumPy arrays are (y, x); SimpleITK images are indexed (x, y)
    sitk_img = sitk.GetImageFromArray(img)
    sitk_img.SetSpacing([1.0, 1.0]); sitk_img.SetOrigin([0.0, 0.0])

    # Crop the ROI (index order: x, y)
    roi_size = [ROI_SIZE, ROI_SIZE]
    roi_start = [ROI_START_X, ROI_START_Y]
    roi_img = sitk.RegionOfInterest(sitk_img, roi_size, roi_start)

    # Resample the ROI to 32×32 by shrinking the spacing
    target_size = [TARGET_SIZE, TARGET_SIZE]
    original_spacing = roi_img.GetSpacing()
    physical_size = [roi_size[i] * original_spacing[i] for i in range(2)]
    target_spacing = [physical_size[i] / target_size[i] for i in range(2)]
    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(target_spacing)
    resampler.SetSize(target_size)
    resampler.SetOutputOrigin(roi_img.GetOrigin())
    resampler.SetOutputDirection(roi_img.GetDirection())
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampled_img = resampler.Execute(roi_img)
    resampled_array = sitk.GetArrayFromImage(resampled_img)

    # Analyze in the 32×32 space
    segment = get_segmentation(resampled_array)
    keypoint = get_keypoint(resampled_array)

    # Map the keypoint back by hand: scale, then shift by the ROI offset (y, x)
    scale_factor = np.array(roi_size) / np.array(target_size)
    keypoint_roi = keypoint * scale_factor
    keypoint_original = keypoint_roi + np.array([roi_start[1], roi_start[0]])

    # Map the mask back with a second resampler, then paste it into the full image
    segment_sitk = sitk.GetImageFromArray(segment.astype(np.float32))
    segment_sitk.SetSpacing(target_spacing)
    segment_sitk.SetOrigin(resampled_img.GetOrigin())
    back_resampler = sitk.ResampleImageFilter()
    back_resampler.SetOutputSpacing(original_spacing)
    back_resampler.SetSize(roi_size)
    back_resampler.SetOutputOrigin(roi_img.GetOrigin())
    back_resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    back_resampler.SetDefaultPixelValue(0)
    segment_roi = back_resampler.Execute(segment_sitk)
    segment_roi_array = sitk.GetArrayFromImage(segment_roi)
    segment_original = np.zeros(img.shape, dtype=np.uint8)
    segment_original[ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X] = segment_roi_array
    return segment_original, keypoint_original
print ("=== SimpleITK ===" )
sitk_segment, sitk_key = process_with_simpleitk(original_img)
plot_result(original_img, sitk_segment, sitk_key, "SimpleITK" , true_keypoint)
Issues: verbose resampler setup, manual coordinate math, and easy axis-order mistakes: SimpleITK indexes images as (x, y) while NumPy arrays use (y, x), so every start index and keypoint has to be flipped by hand.
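The axis-order trap is easy to demonstrate in isolation. A minimal standalone sketch (the array shape here is arbitrary and unrelated to the pipeline above):

import numpy as np
import SimpleITK as sitk

arr = np.zeros((35, 20), dtype=np.float32)  # NumPy shape is (rows = y, cols = x)
arr[12, 7] = 1.0                            # NumPy addresses a pixel as [y, x]

img = sitk.GetImageFromArray(arr)           # SimpleITK reports size in (x, y) order
print(arr.shape)                            # (35, 20)
print(img.GetSize())                        # (20, 35)
print(img.GetPixel(7, 12))                  # 1.0 -- the same pixel, addressed as (x, y)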
Method 2: scipy.ndimage
from scipy.ndimage import zoom

def process_with_scipy(img):
    # Crop and resample with a single zoom factor
    roi = img[ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X]
    factor = TARGET_SIZE / ROI_SIZE
    resampled = zoom(roi, factor, order=1, mode='constant', cval=0)

    # Analyze in the 32×32 space
    segment = get_segmentation(resampled)
    keypoint = get_keypoint(resampled)

    # Map the keypoint back by dividing out the zoom factor
    keypoint_roi = keypoint / factor
    keypoint_original = keypoint_roi + [ROI_START_Y, ROI_START_X]

    # Map the mask back with an inverse zoom, then paste it into the full image
    segment_roi = zoom(segment.astype(np.float32), 1.0 / factor, order=0, mode='constant', cval=0)
    segment_original = np.zeros_like(img, dtype=np.uint8)
    segment_original[ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X] = segment_roi
    return segment_original, keypoint_original
print ("=== scipy.ndimage ===" )
scipy_segment, scipy_key = process_with_scipy(original_img)
plot_result(original_img, scipy_segment, scipy_key, "scipy.ndimage" , true_keypoint)
Issues: manual scale bookkeeping; the naive keypoint / factor back-mapping need not match zoom's own coordinate convention, and reversing the zoom relies on the output size rounding back to exactly 15×15.
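To see where the drift can come from: as we read scipy's documentation, zoom with its default grid_mode=False aligns the centers of the corner pixels (output index i samples input coordinate i * (in - 1) / (out - 1)), while the code above divides by TARGET_SIZE / ROI_SIZE. A rough sketch of the gap, reusing the constants defined earlier:

out_idx = 16.0                                                   # a coordinate detected in the 32×32 crop

naive = out_idx * (ROI_SIZE / TARGET_SIZE)                       # 16 * 15/32 = 7.500 (what the code above does)
centers_aligned = out_idx * (ROI_SIZE - 1) / (TARGET_SIZE - 1)   # 16 * 14/31 ≈ 7.226 (zoom's assumed mapping)

print(naive - centers_aligned)                                   # ≈ 0.27 px from the convention mismatch alone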
Method 3: PyTorch interpolate
import torch
import torch.nn.functional as F

def process_with_pytorch(img):
    # interpolate expects a (batch, channel, H, W) tensor
    tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
    roi_tensor = tensor[:, :, ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X]
    resampled = F.interpolate(roi_tensor, size=(TARGET_SIZE, TARGET_SIZE), mode='bilinear', align_corners=False)
    resampled_arr = resampled.squeeze().numpy()

    # Analyze in the 32×32 space
    segment = get_segmentation(resampled_arr)
    keypoint = get_keypoint(resampled_arr)

    # Map the keypoint back with a plain scale + offset
    scale = ROI_SIZE / TARGET_SIZE
    keypoint_original = keypoint * scale + [ROI_START_Y, ROI_START_X]

    # Map the mask back with nearest-neighbor interpolation, then paste it into the full image
    segment_tensor = torch.from_numpy(segment.astype(np.float32)).unsqueeze(0).unsqueeze(0)
    segment_roi = F.interpolate(segment_tensor, size=(ROI_SIZE, ROI_SIZE), mode='nearest').squeeze().numpy().astype(np.uint8)
    segment_original = np.zeros_like(img, dtype=np.uint8)
    segment_original[ROI_START_Y:ROI_END_Y, ROI_START_X:ROI_END_X] = segment_roi
    return segment_original, keypoint_original
print ("=== PyTorch ===" )
torch_segment, torch_key = process_with_pytorch(original_img)
plot_result(original_img, torch_segment, torch_key, "PyTorch interpolate" , true_keypoint)
Issues: align_corners confusion (the plain keypoint * scale mapping ignores the half-pixel shift implied by align_corners=False), manual batch/channel bookkeeping, and a visibly offset mask after mapping back.
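The correct back-mapping formula depends on the align_corners setting, and neither variant reduces to the plain keypoint * scale used above. A small sketch of the usual bilinear conventions (our reading of them, reusing the constants defined earlier):

out_idx = 16.0                                                  # a coordinate detected in the 32×32 crop
scale = ROI_SIZE / TARGET_SIZE                                  # 15 / 32

plain = out_idx * scale                                         # 7.500 (what the code above does)
half_pixel = (out_idx + 0.5) * scale - 0.5                      # ≈ 7.23 (align_corners=False style mapping)
corner_aligned = out_idx * (ROI_SIZE - 1) / (TARGET_SIZE - 1)   # ≈ 7.23 (align_corners=True style mapping)

print(plain - half_pixel)                                       # ≈ 0.27 px -- the systematic offset seen in the mask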
Results
Accuracy: SpaceTransformer maps results back with zero offset. The other methods show varying coordinate drift; PyTorch visibly shifts the mask.
Developer effort: SpaceTransformer uses a declarative space description and the code stays concise; the others require manual coordinate bookkeeping.
Design: SpaceTransformer separates transformation planning from execution, hiding the fragile math and ensuring consistent behavior.
The next chapter dives into SpaceTransformer’s internals.