spacetransformer.torch.affine_builder

torch.affine_builder

Grid generation utilities for PyTorch grid_sample operations.

This module provides functions for generating 5D sampling grids compatible with PyTorch's grid_sample function, at either default or half (float16) precision, for 3D medical image processing.

Example: Generate a grid for image sampling:

>>> import torch
>>> from spacetransformer.torch.affine_builder import build_grid
>>> 
>>> # Create affine transformation matrix
>>> theta = torch.eye(3, 4)  # Identity transformation
>>> shape = (100, 100, 50)
>>> 
>>> # Build grid for sampling
>>> grid = build_grid(theta, shape)
>>> print(grid.shape)
torch.Size([1, 100, 100, 50, 3])
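The grid is meant to be passed straight to grid_sample. The sketch below shows that hand-off for a single-channel volume. So that it runs without spacetransformer installed, it uses torch.nn.functional.affine_grid as a stand-in for build_grid. Two things here are our assumptions and not documented behavior of the library: that both functions produce the same (N, D, H, W, 3) layout for an identity theta, and that align_corners=True is the right setting.

```python
import torch
import torch.nn.functional as F

# Input volume laid out as (N, C, D, H, W), as grid_sample expects for 5D input.
volume = torch.rand(1, 1, 100, 100, 50)

# Identity affine, batched to shape (N, 3, 4).
theta = torch.eye(3, 4).unsqueeze(0)

# Stand-in for build_grid(theta, (100, 100, 50)) (assumption): affine_grid
# returns an (N, D, H, W, 3) grid of normalized coordinates in [-1, 1].
# align_corners=True is also an assumption; match whatever build_grid uses.
grid = F.affine_grid(theta, size=(1, 1, 100, 100, 50), align_corners=True)

# "bilinear" mode on a 5D input performs trilinear interpolation.
resampled = F.grid_sample(volume, grid, mode="bilinear", align_corners=True)

print(grid.shape)       # torch.Size([1, 100, 100, 50, 3])
print(resampled.shape)  # torch.Size([1, 1, 100, 100, 50])
```

With an identity theta and matching align_corners on both calls, the resampled volume reproduces the input up to floating-point rounding.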

Functions

Name        Description
build_grid  Generate 5D grid for PyTorch grid_sample operations.

build_grid

torch.affine_builder.build_grid(theta, shape, *, half=False)

Generate a 5D grid for PyTorch grid_sample operations.

This function creates a 5D sampling grid by applying an affine transformation, or a batch of them, to a normalized base grid. The result can be passed directly as the grid argument of PyTorch's torch.nn.functional.grid_sample.
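Conceptually, each output voxel (d, h, w) gets a normalized base coordinate (x, y, z) in [-1, 1], with x running along W, y along H and z along D. theta then maps the homogeneous vector (x, y, z, 1) to the location that voxel samples from. The plain-Python sketch below spells this out for a tiny volume. The helper names are hypothetical. The corner-aligned normalization is also an assumption and is not documented behavior of build_grid: -1 and +1 sit at the centers of the first and last voxels, as with align_corners=True.

```python
def normalized_coords(n):
    """Corner-aligned coordinates: n evenly spaced points from -1 to 1."""
    if n == 1:
        return [0.0]
    return [-1.0 + 2.0 * i / (n - 1) for i in range(n)]


def apply_affine(theta, point):
    """Map (x, y, z) through a 3x4 affine given as a list of three rows."""
    x, y, z = point
    return tuple(row[0] * x + row[1] * y + row[2] * z + row[3] for row in theta)


def build_grid_sketch(theta, shape):
    """Hypothetical helper: nested D x H x W list of transformed (x, y, z)."""
    D, H, W = shape
    zs, ys, xs = normalized_coords(D), normalized_coords(H), normalized_coords(W)
    return [
        [[apply_affine(theta, (xs[w], ys[h], zs[d])) for w in range(W)]
         for h in range(H)]
        for d in range(D)
    ]


identity = [[1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0]]

grid = build_grid_sketch(identity, (2, 3, 4))
print(grid[0][0][0])  # (-1.0, -1.0, -1.0): first voxel
print(grid[1][2][3])  # (1.0, 1.0, 1.0): last voxel
```

A tensor implementation would compute the same thing in one batched matrix multiply. Writing it as nested loops makes the (x, y, z) versus (D, H, W) ordering explicit.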

Args:
    theta: Affine transformation matrix with shape (3, 4), or a batch of matrices with shape (N, 3, 4). Row order is fixed as (x, y, z).
    shape: Target volume dimensions (D, H, W).
    half: Whether to use float16 precision for the grid.

Returns:
    torch.Tensor: 5D sampling grid with shape (N, D, H, W, 3), where N is 1 when theta has shape (3, 4).
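The rows of theta follow (x, y, z), while shape is given as (D, H, W), so the first row of theta acts on the last (W) axis. The sketch below checks this, again using torch.nn.functional.affine_grid and grid_sample as a stand-in for build_grid. The stand-in and align_corners=True are our assumptions. With corner-aligned coordinates, translating x by 2 / (W - 1) normalized units moves the sampling point by one voxel along W.

```python
import torch
import torch.nn.functional as F

D, H, W = 2, 3, 5
# Each voxel stores its own W index, so a shift along W is easy to read off.
volume = torch.arange(W, dtype=torch.float32).expand(1, 1, D, H, W).contiguous()

theta = torch.eye(3, 4).unsqueeze(0)
theta[0, 0, 3] = 2.0 / (W - 1)  # row 0 (x): translate by one voxel

grid = F.affine_grid(theta, size=(1, 1, D, H, W), align_corners=True)
out = F.grid_sample(volume, grid, mode="bilinear",
                    padding_mode="zeros", align_corners=True)

# Output voxel w reads input voxel w + 1; the last one falls outside and is 0.
print(out[0, 0, 0, 0])  # tensor([1., 2., 3., 4., 0.])
```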

Examples:

Single transformation:

>>> import torch
>>> from spacetransformer.torch.affine_builder import build_grid
>>> theta = torch.eye(3, 4)  # Identity transformation
>>> grid = build_grid(theta, (50, 100, 100))
>>> print(grid.shape)
torch.Size([1, 50, 100, 100, 3])

Batch of transformations:

>>> batch_theta = torch.eye(3, 4).unsqueeze(0).repeat(5, 1, 1)
>>> batch_grid = build_grid(batch_theta, (50, 100, 100))
>>> print(batch_grid.shape)
torch.Size([5, 50, 100, 100, 3])

Half precision for memory efficiency:

>>> half_grid = build_grid(theta, (50, 100, 100), half=True)
>>> print(half_grid.dtype)
torch.float16
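The saving from half=True is easy to estimate. The grid holds N x D x H x W x 3 values, at 4 bytes each in float32 and 2 bytes in float16. The pure-Python arithmetic below works this out for the (50, 100, 100) example above. The helper name is hypothetical. PyTorch's grid_sample also requires the input and the grid to share a dtype, so a float16 grid should be paired with a float16 input volume.

```python
def grid_nbytes(shape, batch=1, bytes_per_value=4):
    """Hypothetical helper: bytes needed for a (batch, D, H, W, 3) grid."""
    D, H, W = shape
    return batch * D * H * W * 3 * bytes_per_value


shape = (50, 100, 100)
full = grid_nbytes(shape, bytes_per_value=4)  # float32
half = grid_nbytes(shape, bytes_per_value=2)  # float16

print(full, half)  # 6000000 3000000
print(f"{full / 2**20:.2f} MiB -> {half / 2**20:.2f} MiB")  # 5.72 MiB -> 2.86 MiB
```

The saving scales linearly with the batch size and the number of voxels.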