spacetransformer.torch.utils

torch.utils

Utility functions for PyTorch operations in SpaceTransformer.

This module provides utility functions for PyTorch operations, including tensor dimension normalization, type conversion, and device management.

Functions

Name        Description
norm_dim    Normalize tensor dimensions to 5D (batch, channel, depth, height, width).
norm_type   Normalize tensor type, device, and precision.

norm_dim

torch.utils.norm_dim(tensor)

Normalize tensor dimensions to 5D (batch, channel, depth, height, width).

This function converts input tensors of various dimensions to a standard 5D format used in medical image processing. This simplifies operations by ensuring consistent dimension ordering.

Args:
    tensor: Input tensor with 3, 4, or 5 dimensions.
        - 3D: interpreted as (depth, height, width)
        - 4D: interpreted as (channel, depth, height, width)
        - 5D: interpreted as (batch, channel, depth, height, width)

Returns:
    torch.Tensor: Normalized 5D tensor

Raises:
    ValueError: If the input has fewer than 3 or more than 5 dimensions

Example:
    >>> import torch
    >>> from spacetransformer.torch.utils import norm_dim
    >>> img3d = torch.rand(50, 100, 100)  # D, H, W
    >>> img5d = norm_dim(img3d)
    >>> img5d.shape
    torch.Size([1, 1, 50, 100, 100])
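The dimension rule described above can be illustrated on shape tuples alone, without PyTorch. This is a minimal sketch, not the library's implementation: `normalized_shape` is a hypothetical helper introduced here for illustration, and it assumes that missing leading axes are filled with singleton dimensions, as the documented 3D example suggests.

```python
def normalized_shape(shape):
    """Return the 5D (B, C, D, H, W) shape that a 3D, 4D, or 5D input maps to.

    Hypothetical helper for illustration only; it is not part of SpaceTransformer.
    """
    ndim = len(shape)
    if ndim < 3 or ndim > 5:
        raise ValueError(f"expected a 3D, 4D, or 5D input, got {ndim}D")
    # Prepend singleton axes until the shape has five entries:
    # (D, H, W) -> (1, 1, D, H, W) and (C, D, H, W) -> (1, C, D, H, W).
    return (1,) * (5 - ndim) + tuple(shape)


print(normalized_shape((50, 100, 100)))        # (1, 1, 50, 100, 100)
print(normalized_shape((2, 50, 100, 100)))     # (1, 2, 50, 100, 100)
print(normalized_shape((4, 2, 50, 100, 100)))  # (4, 2, 50, 100, 100)
```

On a real tensor, the same effect can be obtained by indexing with `None`, for example `tensor[None, None]` for a 3D input. Whether `norm_dim` does this internally is not stated in this reference.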

norm_type

torch.utils.norm_type(tensor, cuda=False, dtype=None, cuda_device='cuda:0')

Normalize tensor type, device, and precision.

This function converts the input to a torch.Tensor with the requested type, device, and precision. Both NumPy arrays and PyTorch tensors are accepted as input.

Args:
    tensor: Input tensor or array
    cuda: Whether to move tensor to CUDA device
    half: Whether to convert tensor to half precision (float16)
    dtype: Specific dtype to convert tensor to (overrides half)
    cuda_device: CUDA device to use if cuda=True

Returns:
    torch.Tensor: Normalized tensor with specified properties

Example:
    >>> import numpy as np
    >>> from spacetransformer.torch.utils import norm_type
    >>> array = np.random.rand(100, 100, 50).astype(np.float32)
    >>> tensor = norm_type(array, cuda=True, half=True)
    >>> tensor.device, tensor.dtype
    (device(type='cuda', index=0), torch.float16)
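The precedence between `half` and `dtype` stated in the argument list can be sketched as a small, PyTorch-free function. `resolve_target` is a hypothetical helper written for illustration, not part of SpaceTransformer. The sketch assumes that `cuda=False` leaves the device unchanged and that omitting both `half` and `dtype` leaves the dtype unchanged; neither behavior is confirmed by this reference. Dtypes are represented as plain strings to keep the example self-contained.

```python
def resolve_target(cuda=False, half=False, dtype=None, cuda_device="cuda:0"):
    """Return the (device, dtype) a tensor would be converted to.

    Hypothetical helper for illustration only. A value of None means
    "leave unchanged", which is an assumption of this sketch.
    """
    # The device is only set when CUDA is requested.
    device = cuda_device if cuda else None

    if dtype is not None:
        # An explicit dtype overrides the half flag, as documented.
        target_dtype = dtype
    elif half:
        target_dtype = "float16"
    else:
        target_dtype = None
    return device, target_dtype


print(resolve_target(cuda=True, half=True))        # ('cuda:0', 'float16')
print(resolve_target(half=True, dtype="float64"))  # (None, 'float64')
print(resolve_target())                            # (None, None)
```

With real tensors, the resolved values would typically be passed to `torch.Tensor.to`, which accepts both a device and a dtype.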