spacetransformer.torch.image_warpers
GPU-accelerated image resampling for 3D medical images.
This module provides GPU-accelerated image warping for 3D medical images using PyTorch. It supports several interpolation modes and picks an optimized code path according to the geometric relationship between the source and target spaces.
Example: Basic image warping between spaces:
>>> import torch
>>> from spacetransformer.core import Space
>>> from spacetransformer.torch.image_warpers import warp_image
>>>
>>> # Create test image and spaces
>>> image = torch.rand(100, 100, 50)
>>> source = Space(shape=(100, 100, 50), spacing=(1.0, 1.0, 2.0))
>>> target = Space(shape=(50, 50, 25), spacing=(2.0, 2.0, 4.0))
>>>
>>> # Warp image
>>> warped = warp_image(image, source, target, pad_value=0.0)
>>> print(warped.shape)
torch.Size([50, 50, 25])
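For axis-aligned spaces, the correspondence in the example above reduces to per-axis arithmetic. A target voxel index maps to a world coordinate through the target spacing, and then back to a (possibly fractional) source index through the source spacing. The plain-Python sketch below illustrates that mapping. It assumes identity axis directions and origins of zero by default, and `target_to_source_index` is a hypothetical helper, not part of spacetransformer.

```python
def target_to_source_index(tgt_index, tgt_spacing, src_spacing,
                           tgt_origin=(0.0, 0.0, 0.0),
                           src_origin=(0.0, 0.0, 0.0)):
    """Map a target voxel index to a fractional source voxel index.

    Hypothetical helper: assumes axis-aligned spaces (identity direction),
    so each axis can be handled independently.
    """
    source_index = []
    for i, ts, ss, to, so in zip(tgt_index, tgt_spacing, src_spacing,
                                 tgt_origin, src_origin):
        world = to + i * ts                      # target index -> world (mm)
        source_index.append((world - so) / ss)   # world -> source index
    return tuple(source_index)


# Spacings from the example: source (1, 1, 2) mm, target (2, 2, 4) mm.
print(target_to_source_index((10, 20, 5), (2.0, 2.0, 4.0), (1.0, 1.0, 2.0)))
# (20.0, 40.0, 10.0)
```

Every target index lands on a source index scaled by exactly 2 on each axis, so the two spaces in this example differ only by a per-axis zoom.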
Functions
| Name | Description |
|---|---|
| warp_image | Resample image from source space to target space using GPU acceleration. |
| warp_image_batch | Batch resampling of a single image to multiple target spaces. |
| warp_image_with_argmax | Not implemented yet. |
warp_image
torch.image_warpers.warp_image(
img,
source,
target,
*,
pad_value,
mode='trilinear',
pad_mode='constant',
half=False,
numpy=False,
cuda_device='cuda:0',
)
Resample image from source space to target space using GPU acceleration.
This function resamples a 3D image from the source coordinate space into the target coordinate space. It automatically chooses a strategy based on the geometric relationship between the two spaces:

- Direct copy for identical spaces
- Empty output for non-overlapping spaces
- Fast tensor operations for flip/permute transformations
- Optimized interpolation for zoom operations
- General grid sampling for arbitrary transformations
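The dispatch described above can be sketched as a classification of the affine map that takes target voxel indices to source voxel indices. Everything in this plain-Python sketch is hypothetical. `SimpleSpace`, `index_affine`, and `choose_strategy` are illustrative names, not spacetransformer's API, and the classification rules are assumptions about how such a dispatcher could work. They do not describe the library's actual checks.

```python
from dataclasses import dataclass

IDENTITY = ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0))


@dataclass(frozen=True)
class SimpleSpace:
    """Hypothetical stand-in for a geometric space (not spacetransformer's Space).

    world = origin + direction @ (spacing * index); column j of `direction`
    is the unit world-space direction of voxel axis j.
    """
    shape: tuple
    spacing: tuple
    origin: tuple = (0.0, 0.0, 0.0)
    direction: tuple = IDENTITY


def index_affine(src, tgt):
    """Return (m, t) with source_index = m @ target_index + t."""
    # rel = direction_src^T @ direction_tgt (orthonormal, so transpose = inverse)
    rel = [[sum(src.direction[k][a] * tgt.direction[k][b] for k in range(3))
            for b in range(3)] for a in range(3)]
    m = [[rel[a][b] * tgt.spacing[b] / src.spacing[a] for b in range(3)]
         for a in range(3)]
    d = [tgt.origin[k] - src.origin[k] for k in range(3)]
    t = [sum(src.direction[k][a] * d[k] for k in range(3)) / src.spacing[a]
         for a in range(3)]
    return m, t


def choose_strategy(src, tgt, tol=1e-6):
    """Classify the src/tgt relationship into one of five hypothetical paths."""
    if src == tgt:
        return "identical"
    m, t = index_affine(src, tgt)
    # Map the 8 corners of the target index box into source index space.
    corners = [(i, j, k) for i in (0, tgt.shape[0] - 1)
               for j in (0, tgt.shape[1] - 1)
               for k in (0, tgt.shape[2] - 1)]
    mapped = [[sum(m[a][b] * c[b] for b in range(3)) + t[a] for a in range(3)]
              for c in corners]
    for a in range(3):
        lo = min(p[a] for p in mapped)
        hi = max(p[a] for p in mapped)
        # Conservative: report "empty" only if no target voxel comes within
        # one index of the source grid on this axis.
        if hi <= -1 or lo >= src.shape[a]:
            return "empty"
    nonzero = [[abs(v) > tol for v in row] for row in m]
    is_perm = (all(sum(row) == 1 for row in nonzero)
               and all(sum(row[b] for row in nonzero) == 1 for b in range(3)))
    unit = all(abs(abs(v) - 1.0) < tol for row in m for v in row if abs(v) > tol)
    integral = all(abs(x - round(x)) < tol for x in t)
    if is_perm and unit and integral:
        return "flip_permute"   # pure index reordering / reversal
    diagonal = all(abs(m[a][b]) < tol
                   for a in range(3) for b in range(3) if a != b)
    if diagonal and all(m[a][a] > 0 for a in range(3)):
        return "zoom"           # per-axis scaling plus a translation
    return "general"            # anything else: full grid sampling


src = SimpleSpace(shape=(100, 100, 50), spacing=(1.0, 1.0, 2.0))
tgt = SimpleSpace(shape=(50, 50, 25), spacing=(2.0, 2.0, 4.0))
print(choose_strategy(src, src))  # identical
print(choose_strategy(src, tgt))  # zoom
```

The non-overlap test uses the bounding box of the mapped target corners. It can miss some non-overlapping rotated cases, which then fall through to the general path, but it never labels an overlapping pair as empty.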
Args:
- img: Input image tensor or array. Supports 3D (D,H,W), 4D (C,D,H,W), or 5D (B,C,D,H,W) formats.
- source: Source geometric space defining the input image coordinates.
- target: Target geometric space for the output image coordinates.
- pad_value: Padding value for regions outside the source image bounds.
- mode: Interpolation mode ("trilinear", "nearest", "bicubic").
- pad_mode: Padding mode for boundary handling ("constant", "reflect", etc.).
- half: Whether to use half precision (float16) for computation.
- numpy: Whether to return a NumPy array instead of a tensor.
- cuda_device: CUDA device for GPU computation.

Returns: Resampled image in the target space, with the same type as the input (unless numpy=True).

Raises:
- ValidationError: If input dimensions are invalid.
- CudaError: If CUDA operations fail.
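PyTorch's `torch.nn.functional.grid_sample` and `torch.nn.functional.interpolate` operate on volumetric inputs shaped (N, C, D, H, W). A common way to accept all three documented layouts is therefore to promote every input to 5D and undo the promotion afterwards. Whether `warp_image` does exactly this internally is an assumption. The helpers below are hypothetical and only track shapes.

```python
def to_5d_shape(shape):
    """Promote a 3D/4D/5D image shape to (B, C, D, H, W).

    Hypothetical helper. Returns the 5D shape and the number of leading
    singleton axes that were added, so the result can be squeezed back to
    the caller's original rank after resampling.
    """
    ndim = len(shape)
    if ndim not in (3, 4, 5):
        # Plays the role of the ValidationError documented above.
        raise ValueError(f"expected a 3D, 4D, or 5D image, got {ndim}D")
    added = 5 - ndim
    return (1,) * added + tuple(shape), added


def restore_shape(shape_5d, added):
    """Drop the leading singleton axes added by to_5d_shape."""
    return tuple(shape_5d[added:])


shape_5d, added = to_5d_shape((100, 100, 50))
print(shape_5d, added)                            # (1, 1, 100, 100, 50) 2
print(restore_shape((1, 1, 50, 50, 25), added))  # (50, 50, 25)
```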
Example: Basic image resampling:
>>> import torch
>>> from spacetransformer.core import Space
>>> from spacetransformer.torch.image_warpers import warp_image
>>>
>>> # Create test image and spaces
>>> image = torch.rand(100, 100, 50)
>>> source = Space(shape=(100, 100, 50), spacing=(1.0, 1.0, 2.0))
>>> target = Space(shape=(50, 50, 25), spacing=(2.0, 2.0, 4.0))
>>>
>>> # Resample to target space
>>> resampled = warp_image(image, source, target, pad_value=0.0)
>>> print(resampled.shape)
torch.Size([50, 50, 25])
Using different interpolation modes:
>>> # Nearest neighbor for label images
>>> labels = torch.randint(0, 5, (100, 100, 50))
>>> resampled_labels = warp_image(labels, source, target,
... pad_value=0, mode="nearest")
>>>
>>> # Half precision for memory efficiency
>>> resampled_half = warp_image(image, source, target,
... pad_value=0.0, half=True)
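Both choices above can be motivated with a few lines of plain Python, independent of spacetransformer. Linear interpolation between two different label values produces an intermediate value that is not a valid label, while nearest-neighbor lookup always returns an existing one. Storing voxels as float16 also halves their size compared with float32. The helpers `linear_1d` and `nearest_1d` are illustrative only.

```python
import struct


def linear_1d(values, x):
    """Linearly interpolate a 1D list at fractional position x (0 <= x < len - 1)."""
    i = int(x)                 # left neighbor
    w = x - i                  # weight of the right neighbor
    return (1 - w) * values[i] + w * values[i + 1]


def nearest_1d(values, x):
    """Nearest-neighbor lookup at fractional position x."""
    return values[int(round(x))]


labels = [1, 1, 3, 3]          # two adjacent structures, labels 1 and 3
x = 1.5                        # halfway between voxel 1 (label 1) and voxel 2 (label 3)
print(linear_1d(labels, x))    # 2.0: a label that exists in neither structure
print(nearest_1d(labels, x))   # 3: always one of the original labels

# Memory for the 100 x 100 x 50 example volume.
voxels = 100 * 100 * 50
print(voxels * struct.calcsize("f"))  # 2000000 bytes as float32
print(voxels * struct.calcsize("e"))  # 1000000 bytes as float16
```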
warp_image_batch
torch.image_warpers.warp_image_batch(
img,
source,
targets,
*,
pad_value,
mode='trilinear',
pad_mode='constant',
half=False,
cuda_device='cuda:0',
)
Batch resampling of a single image to multiple target spaces.
This function efficiently resamples the same input image to multiple target spaces, avoiding redundant memory transfers and data preparation. It’s useful for multi-resolution or multi-view processing scenarios.
Args:
- img: Input image tensor or array in 3D (D,H,W), 4D (C,D,H,W), or 5D (B,C,D,H,W) format.
- source: Source space defining the input image coordinates.
- targets: List of target spaces, each defining one output space.
- pad_value: Padding value for regions outside the source image bounds.
- mode: Interpolation mode ("trilinear", "nearest", "bicubic").
- pad_mode: Padding mode for boundary handling ("constant", "reflect", etc.).
- half: Whether to use half precision (float16) for computation.
- cuda_device: CUDA device for GPU computation.

Returns: List[torch.Tensor]: List of resampled images, always returned as GPU tensors.

Raises:
- ValidationError: If input dimensions are invalid.
- CudaError: If CUDA operations fail.
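For the multi-resolution use case, the target list is often a pyramid built from the source geometry. The plain-Python sketch below computes one (shape, spacing) pair per level. `pyramid_specs` is a hypothetical helper. Each level scales the spacing by an integer factor and rounds the voxel count up, so the level spans at least the original physical extent (shape × spacing).

```python
import math


def pyramid_specs(shape, spacing, factors=(1, 2, 4)):
    """Return (shape, spacing) pairs for a multi-resolution pyramid.

    Hypothetical helper: spacing is multiplied by each factor and the shape
    is divided by it, rounding up so no part of the extent is dropped.
    """
    specs = []
    for f in factors:
        new_spacing = tuple(s * f for s in spacing)
        new_shape = tuple(math.ceil(n / f) for n in shape)
        specs.append((new_shape, new_spacing))
    return specs


for shape, spacing in pyramid_specs((100, 100, 50), (1.0, 1.0, 2.0)):
    print(shape, spacing)
# (100, 100, 50) (1.0, 1.0, 2.0)
# (50, 50, 25) (2.0, 2.0, 4.0)
# (25, 25, 13) (4.0, 4.0, 8.0)
```

Each pair could then be turned into a target with `Space(shape=..., spacing=...)`, using the same keyword arguments as the examples above. The resulting list would be passed as `targets`.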
warp_image_with_argmax
torch.image_warpers.warp_image_with_argmax(*args, **kwargs)
Not implemented yet.
This function is reserved for future implementation of combined image warping and argmax operations for efficient segmentation map resampling.
Raises: NotImplementedError: Always raised as this function is not yet implemented
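Combining warping with argmax usually means resampling a label map as per-class channels and then picking the most likely class at each output voxel. This gives smoother boundaries than nearest-neighbor resampling. The sketch below shows that general technique in one dimension. It is not the planned implementation of this function, and `resample_labels_argmax` is a hypothetical name.

```python
def resample_labels_argmax(labels, num_classes, positions):
    """Resample a 1D label map by interpolating one-hot channels, then argmax.

    Hypothetical sketch of the one-hot / interpolate / argmax technique.
    `positions` are fractional source indices in [0, len(labels) - 1].
    """
    # One-hot encode: channel c is 1.0 where the label equals c.
    channels = [[1.0 if v == c else 0.0 for v in labels]
                for c in range(num_classes)]
    out = []
    for x in positions:
        i = min(int(x), len(labels) - 2)  # left neighbor, clamped so i + 1 is valid
        w = x - i                          # linear weight of the right neighbor
        scores = [(1 - w) * ch[i] + w * ch[i + 1] for ch in channels]
        out.append(max(range(num_classes), key=scores.__getitem__))
    return out


print(resample_labels_argmax([0, 0, 2, 2, 1], 3, [0.0, 1.4, 1.6, 3.2, 4.0]))
# [0, 0, 2, 2, 1]
```

Interpolating the raw label values at position 1.6 would instead give 0.4 × 0 + 0.6 × 2 = 1.2, which rounds to class 1 even though class 1 does not occur between voxels 1 and 2.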