spacetransformer.torch.affine_builder
Grid generation utilities for PyTorch grid_sample operations.
This module provides functions for generating 5D grids compatible with PyTorch’s grid_sample function, supporting both regular and half-precision computations for 3D medical image processing.
Example: Generate a grid for image sampling:
```python
>>> import torch
>>> from spacetransformer.torch.affine_builder import build_grid
>>>
>>> # Create an affine transformation matrix
>>> theta = torch.eye(3, 4)  # Identity transformation
>>> shape = (100, 100, 50)
>>>
>>> # Build grid for sampling
>>> grid = build_grid(theta, shape)
>>> print(grid.shape)
torch.Size([1, 100, 100, 50, 3])
```
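The grid is meant to be passed as the grid argument of torch.nn.functional.grid_sample. The sketch below shows that pipeline end to end. Because build_grid itself is not reproduced here, it uses torch.nn.functional.affine_grid as a stand-in, assuming (not confirmed by this page) that both produce an (N, D, H, W, 3) grid of normalized coordinates with corner-aligned endpoints (align_corners=True). The volume size is arbitrary. With an identity transform, resampling should return the input essentially unchanged.

```python
import torch
import torch.nn.functional as F

# Illustrative input volume: batch 1, 1 channel, (D, H, W) = (8, 16, 12)
volume = torch.rand(1, 1, 8, 16, 12)

# Identity affine; for 5D sampling, affine_grid expects a batch of (3, 4) matrices
theta = torch.eye(3, 4).unsqueeze(0)  # shape (1, 3, 4)

# Stand-in for build_grid (assumption: same layout and normalization)
grid = F.affine_grid(theta, size=list(volume.shape), align_corners=True)

# For 5D inputs, mode="bilinear" performs trilinear interpolation
resampled = F.grid_sample(volume, grid, mode="bilinear", align_corners=True)

print(grid.shape)       # torch.Size([1, 8, 16, 12, 3])
print(resampled.shape)  # torch.Size([1, 1, 8, 16, 12])
print(torch.allclose(resampled, volume, atol=1e-5))  # True
```

The same align_corners value must be used when building and consuming the grid; mixing the two conventions shifts every sample by up to half a voxel.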
Functions
| Name | Description |
|---|---|
| build_grid | Generate 5D grid for PyTorch grid_sample operations. |
build_grid
torch.affine_builder.build_grid(theta, shape, *, half=False)

Generate 5D grid for PyTorch grid_sample operations.
This function creates a 5D sampling grid by applying the affine transformation(s) in theta to a normalized base grid. The result can be passed directly as the grid argument of PyTorch’s torch.nn.functional.grid_sample (commonly imported as F.grid_sample).
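A minimal sketch of the technique described above, for illustration only: sketch_build_grid is a hypothetical helper, not the library's implementation. It assumes a corner-aligned base grid spanning [-1, 1] on every axis, a last dimension ordered (x, y, z) as F.grid_sample expects, and that half=True simply casts the finished grid to float16. Under those assumptions its output should match torch.nn.functional.affine_grid with align_corners=True.

```python
import torch
import torch.nn.functional as F


def sketch_build_grid(theta: torch.Tensor, shape: tuple[int, int, int],
                      half: bool = False) -> torch.Tensor:
    """Hypothetical sketch: apply affine matrices to a normalized base grid."""
    if theta.dim() == 2:  # a single (3, 4) matrix becomes a batch of one
        theta = theta.unsqueeze(0)
    d, h, w = shape
    opts = dict(dtype=theta.dtype, device=theta.device)
    # Corner-aligned normalized coordinates in [-1, 1] along each axis
    zs = torch.linspace(-1.0, 1.0, d, **opts)
    ys = torch.linspace(-1.0, 1.0, h, **opts)
    xs = torch.linspace(-1.0, 1.0, w, **opts)
    z, y, x = torch.meshgrid(zs, ys, xs, indexing="ij")  # each (D, H, W)
    # Homogeneous base grid, last dim ordered (x, y, z, 1): shape (D, H, W, 4)
    base = torch.stack((x, y, z, torch.ones_like(x)), dim=-1)
    # grid[n, d, h, w, r] = sum_k theta[n, r, k] * base[d, h, w, k]
    grid = torch.einsum("dhwk,nrk->ndhwr", base, theta)  # (N, D, H, W, 3)
    # Assumption: half precision is applied by casting the finished grid
    return grid.half() if half else grid


theta = torch.tensor([[0.9, 0.1, 0.0, 0.2],
                      [0.0, 1.1, 0.0, -0.1],
                      [0.0, 0.0, 1.0, 0.05]])
mine = sketch_build_grid(theta, (5, 6, 7))
ref = F.affine_grid(theta.unsqueeze(0), size=(1, 1, 5, 6, 7), align_corners=True)
print(mine.shape)                            # torch.Size([1, 5, 6, 7, 3])
print(torch.allclose(mine, ref, atol=1e-5))  # True
```

Computing in the input precision and casting only at the end keeps the affine arithmetic accurate; whether build_grid does the same is not stated on this page.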
Args:
- theta: Affine transformation matrix or batch of matrices, with shape (3, 4) or (N, 3, 4). Row order is fixed as (x, y, z).
- shape: Target volume dimensions (D, H, W).
- half: Whether to use float16 precision for the grid (default: False).
Returns:
- torch.Tensor: 5D sampling grid with shape (N, D, H, W, 3), where N is 1 when theta has shape (3, 4).
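To make "row order (x, y, z)" concrete: in F.grid_sample's convention, the last grid channel holds (x, y, z), where x indexes W, y indexes H and z indexes D, so row 0 of theta moves samples along W. The sketch below uses F.affine_grid as a stand-in for build_grid, assuming build_grid follows the same convention (it is meant for direct use with F.grid_sample) and corner-aligned normalization (align_corners=True). The volume and the one-voxel shift are illustrative.

```python
import torch
import torch.nn.functional as F

D, H, W = 4, 5, 6
# Voxel values are their flat index, so shifts are easy to check
volume = torch.arange(D * H * W, dtype=torch.float32).reshape(1, 1, D, H, W)

# Row 0 of theta produces the x coordinate, which indexes the W axis.
# With align_corners=True, one voxel along W spans 2 / (W - 1) normalized units.
theta = torch.eye(3, 4)
theta[0, 3] = 2.0 / (W - 1)

# Stand-in for build_grid (assumption: same axis order and normalization)
grid = F.affine_grid(theta.unsqueeze(0), size=(1, 1, D, H, W), align_corners=True)
out = F.grid_sample(volume, grid, mode="bilinear", align_corners=True)

# Each output voxel reads its neighbour one step further along W;
# the last W column samples outside the volume and is excluded.
print(torch.allclose(out[..., :-1], volume[..., 1:], atol=1e-4))  # True
```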
Example: Single transformation:
```python
>>> import torch
>>> from spacetransformer.torch.affine_builder import build_grid
>>> theta = torch.eye(3, 4)  # Identity transformation
>>> grid = build_grid(theta, (50, 100, 100))
>>> print(grid.shape)
torch.Size([1, 50, 100, 100, 3])
```
Batch of transformations:
```python
>>> batch_theta = torch.eye(3, 4).unsqueeze(0).repeat(5, 1, 1)
>>> batch_grid = build_grid(batch_theta, (50, 100, 100))
>>> print(batch_grid.shape)
torch.Size([5, 50, 100, 100, 3])
```
Half precision for memory efficiency:
```python
>>> half_grid = build_grid(theta, (50, 100, 100), half=True)
>>> print(half_grid.dtype)
torch.float16
```
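The memory saving from half=True is easy to quantify: a grid stores N × D × H × W × 3 values, and float16 uses 2 bytes per value instead of float32's 4. The sketch below only allocates empty tensors with the grid shape from the examples above to count bytes; it does not call build_grid.

```python
import torch

# Grid shape from the single-transformation example: (N, D, H, W, 3)
grid_shape = (1, 50, 100, 100, 3)

grid32 = torch.empty(grid_shape, dtype=torch.float32)
grid16 = torch.empty(grid_shape, dtype=torch.float16)

# Bytes = number of elements x bytes per element
bytes32 = grid32.nelement() * grid32.element_size()
bytes16 = grid16.nelement() * grid16.element_size()
print(bytes32, bytes16)  # 6000000 3000000
```

That is roughly 5.7 MiB versus 2.9 MiB per grid. Note that F.grid_sample requires the input volume and the grid to share a dtype, so a float16 grid is typically paired with a float16 volume, most commonly on a GPU.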