spacetransformer.torch.image_warpers

GPU-accelerated image resampling for 3D medical images.

This module provides efficient GPU-accelerated image warping capabilities for 3D medical images using PyTorch. It supports various interpolation modes and optimization strategies for different transformation scenarios.

Example: Basic image warping between spaces:

>>> import torch
>>> import numpy as np
>>> from spacetransformer.core import Space
>>> from spacetransformer.torch.image_warpers import warp_image
>>> 
>>> # Create test image and spaces
>>> image = torch.rand(100, 100, 50)
>>> source = Space(shape=(100, 100, 50), spacing=(1.0, 1.0, 2.0))
>>> target = Space(shape=(50, 50, 25), spacing=(2.0, 2.0, 4.0))
>>> 
>>> # Warp image
>>> warped = warp_image(image, source, target, pad_value=0.0)
>>> print(warped.shape)
torch.Size([50, 50, 25])

Functions

Name                     Description
warp_image               Resample image from source space to target space using GPU acceleration.
warp_image_batch         Batch resampling of a single image to multiple target spaces.
warp_image_with_argmax   Not implemented yet.

warp_image

torch.image_warpers.warp_image(
    img,
    source,
    target,
    *,
    pad_value,
    mode='trilinear',
    pad_mode='constant',
    half=False,
    numpy=False,
    cuda_device='cuda:0',
)

Resample image from source space to target space using GPU acceleration.

This function performs efficient 3D image resampling from source to target coordinate space using various optimization strategies based on the geometric relationship between the spaces.

The function automatically chooses the best strategy:

- Direct copy for identical spaces
- Empty output for non-overlapping spaces
- Fast tensor operations for flip/permute transformations
- Optimized interpolation for zoom operations
- General grid sampling for arbitrary transformations
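
Based on the strategy list above, warping with identical source and target spaces takes the direct-copy path. A minimal sketch (the result's device and dtype may depend on half and cuda_device, so only the shape is checked here):

>>> import torch
>>> from spacetransformer.core import Space
>>> from spacetransformer.torch.image_warpers import warp_image
>>> 
>>> img = torch.rand(20, 20, 10)
>>> space = Space(shape=(20, 20, 10), spacing=(1.0, 1.0, 1.0))
>>> copied = warp_image(img, space, space, pad_value=0.0)
>>> copied.shape
torch.Size([20, 20, 10])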

Args:
    img: Input image tensor or array. Supports 3D (D,H,W), 4D (C,D,H,W), or 5D (B,C,D,H,W) formats
    source: Source geometric space defining input image coordinates
    target: Target geometric space for output image coordinates
    pad_value: Padding value for regions outside source image bounds
    mode: Interpolation mode ("trilinear", "nearest", "bicubic")
    pad_mode: Padding mode for boundary handling ("constant", "reflect", etc.)
    half: Whether to use half-precision (float16) for computation
    numpy: Whether to return numpy array instead of tensor
    cuda_device: CUDA device for GPU computation

Returns:
    Resampled image in target space with same type as input (unless numpy=True)

Raises:
    ValidationError: If input dimensions are invalid
    CudaError: If CUDA operations fail

Example: Basic image resampling:

>>> import torch
>>> from spacetransformer.core import Space
>>> from spacetransformer.torch.image_warpers import warp_image
>>> 
>>> # Create test image and spaces
>>> image = torch.rand(100, 100, 50)
>>> source = Space(shape=(100, 100, 50), spacing=(1.0, 1.0, 2.0))
>>> target = Space(shape=(50, 50, 25), spacing=(2.0, 2.0, 4.0))
>>> 
>>> # Resample to target space
>>> resampled = warp_image(image, source, target, pad_value=0.0)
>>> print(resampled.shape)
torch.Size([50, 50, 25])

Using different interpolation modes:

>>> # Nearest neighbor for label images
>>> labels = torch.randint(0, 5, (100, 100, 50))
>>> resampled_labels = warp_image(labels, source, target, 
...                              pad_value=0, mode="nearest")
>>> 
>>> # Half precision for memory efficiency
>>> resampled_half = warp_image(image, source, target, 
...                            pad_value=0.0, half=True)
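
The numpy flag documented above can be used to get a NumPy array back directly; a small sketch reusing the same image, source, and target:

>>> # Return a NumPy array instead of a tensor
>>> import numpy as np
>>> resampled_np = warp_image(image, source, target, pad_value=0.0, numpy=True)
>>> isinstance(resampled_np, np.ndarray)
True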

warp_image_batch

torch.image_warpers.warp_image_batch(
    img,
    source,
    targets,
    *,
    pad_value,
    mode='trilinear',
    pad_mode='constant',
    half=False,
    cuda_device='cuda:0',
)

Batch resampling of a single image to multiple target spaces.

This function efficiently resamples the same input image to multiple target spaces, avoiding redundant memory transfers and data preparation. It’s useful for multi-resolution or multi-view processing scenarios.

Args:
    img: Input image tensor or array in 3D (D,H,W), 4D (C,D,H,W), or 5D (B,C,D,H,W) format
    source: Source space defining input image coordinates
    targets: List of target spaces, each defining an output space
    pad_value: Padding value for regions outside source image bounds
    mode: Interpolation mode ("trilinear", "nearest", "bicubic")
    pad_mode: Padding mode for boundary handling ("constant", "reflect", etc.)
    half: Whether to use half-precision (float16) for computation
    cuda_device: CUDA device for GPU computation

Returns:
    List[torch.Tensor]: List of resampled images, always as GPU tensors

Raises:
    ValidationError: If input dimensions are invalid
    CudaError: If CUDA operations fail
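
Example: Resampling one image to several target spaces. A minimal sketch assuming the same Space construction as in the examples above:

>>> import torch
>>> from spacetransformer.core import Space
>>> from spacetransformer.torch.image_warpers import warp_image_batch
>>> 
>>> image = torch.rand(100, 100, 50)
>>> source = Space(shape=(100, 100, 50), spacing=(1.0, 1.0, 2.0))
>>> targets = [
...     Space(shape=(50, 50, 25), spacing=(2.0, 2.0, 4.0)),
...     Space(shape=(25, 25, 13), spacing=(4.0, 4.0, 8.0)),
... ]
>>> 
>>> # Each result matches its target shape; results stay on the GPU
>>> warped = warp_image_batch(image, source, targets, pad_value=0.0)
>>> [tuple(w.shape) for w in warped]
[(50, 50, 25), (25, 25, 13)]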

warp_image_with_argmax

torch.image_warpers.warp_image_with_argmax(*args, **kwargs)

Not implemented yet.

This function is reserved for future implementation of combined image warping and argmax operations for efficient segmentation map resampling.

Raises:
    NotImplementedError: Always raised as this function is not yet implemented
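
Until this function is implemented, one possible interim approach is to resample a multi-channel probability map with warp_image and take the argmax afterwards. This sketch assumes warp_image preserves the channel dimension for 4D (C,D,H,W) inputs, as described in its Args section:

>>> import torch
>>> from spacetransformer.core import Space
>>> from spacetransformer.torch.image_warpers import warp_image
>>> 
>>> source = Space(shape=(100, 100, 50), spacing=(1.0, 1.0, 2.0))
>>> target = Space(shape=(50, 50, 25), spacing=(2.0, 2.0, 4.0))
>>> probs = torch.rand(3, 100, 100, 50)  # 3-class probability map
>>> 
>>> # Resample per-class probabilities, then collapse to a label map
>>> warped_probs = warp_image(probs, source, target, pad_value=0.0)
>>> labels = warped_probs.argmax(dim=0)
>>> labels.shape
torch.Size([50, 50, 25])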