spacetransformer.torch.utils
Utility functions for PyTorch operations in SpaceTransformer.
This module provides utility functions for PyTorch operations, including tensor dimension normalization, type conversion, and device management.
Functions
| Name | Description |
|---|---|
| norm_dim | Normalize tensor dimensions to 5D (batch, channel, depth, height, width). |
| norm_type | Normalize tensor type, device, and precision. |
norm_dim
torch.utils.norm_dim(tensor)

Normalize tensor dimensions to 5D (batch, channel, depth, height, width).
This function converts input tensors of various dimensions to a standard 5D format used in medical image processing. This simplifies operations by ensuring consistent dimension ordering.
Args:
- tensor: Input tensor with 3, 4, or 5 dimensions:
  - 3D: interpreted as (depth, height, width)
  - 4D: interpreted as (channel, depth, height, width)
  - 5D: interpreted as (batch, channel, depth, height, width)

Returns:
- torch.Tensor: Normalized 5D tensor.

Raises:
- ValueError: If the input has fewer than 3 or more than 5 dimensions.
Example:

```python
>>> import torch
>>> from spacetransformer.torch.utils import norm_dim
>>> img3d = torch.rand(50, 100, 100)  # D, H, W
>>> img5d = norm_dim(img3d)
>>> img5d.shape
torch.Size([1, 1, 50, 100, 100])
```
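The dimension rule above can be illustrated with a minimal sketch. It assumes, based on the documented example, that norm_dim only prepends singleton axes, so a 4D (C, D, H, W) input becomes (1, C, D, H, W). The name `norm_dim_sketch` is hypothetical; this is not the library's actual implementation.

```python
import torch


def norm_dim_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation of norm_dim, for illustration only."""
    ndim = tensor.dim()
    if ndim < 3 or ndim > 5:
        raise ValueError(f"Expected a 3D, 4D, or 5D tensor, got {ndim}D")
    # Prepend singleton axes until the layout is (B, C, D, H, W):
    #   3D (D, H, W)    -> (1, 1, D, H, W)
    #   4D (C, D, H, W) -> (1, C, D, H, W)
    #   5D input is returned unchanged.
    while tensor.dim() < 5:
        tensor = tensor.unsqueeze(0)
    return tensor


img3d = torch.rand(50, 100, 100)  # D, H, W
print(norm_dim_sketch(img3d).shape)  # torch.Size([1, 1, 50, 100, 100])
```

Because `unsqueeze` returns a view, this approach adds the batch and channel axes without copying the voxel data.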
norm_type
torch.utils.norm_type(tensor, cuda=False, dtype=None, cuda_device='cuda:0')

Normalize tensor type, device, and precision.
This function converts the input to a PyTorch tensor with the requested dtype (precision) and device. The input may be either a NumPy array or a PyTorch tensor.
Args:
- tensor: Input NumPy array or PyTorch tensor.
- cuda: Whether to move the tensor to a CUDA device.
- dtype: Target dtype for the tensor, for example torch.float16 for half precision.
- cuda_device: CUDA device to use if cuda=True.

Returns:
- torch.Tensor: Normalized tensor with the requested dtype and device.
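Example:

```python
>>> import numpy as np
>>> import torch
>>> from spacetransformer.torch.utils import norm_type
>>> array = np.random.rand(100, 100, 50).astype(np.float32)
>>> # Requires a CUDA-capable GPU
>>> tensor = norm_type(array, cuda=True, dtype=torch.float16)
>>> tensor.device, tensor.dtype
(device(type='cuda', index=0), torch.float16)
```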
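The conversion steps described above can be sketched as follows. This is a minimal sketch under stated assumptions: non-tensor input is wrapped with `torch.as_tensor`, the input dtype is kept when `dtype` is None, and the tensor moves to `cuda_device` only when `cuda=True`. The name `norm_type_sketch` is hypothetical; it mirrors the documented signature but is not the library's code.

```python
from typing import Optional, Union

import numpy as np
import torch


def norm_type_sketch(
    tensor: Union[np.ndarray, torch.Tensor],
    cuda: bool = False,
    dtype: Optional[torch.dtype] = None,
    cuda_device: str = "cuda:0",
) -> torch.Tensor:
    """Hypothetical re-implementation of norm_type, for illustration only."""
    # Wrap NumPy arrays (or other array-likes); existing tensors pass through.
    if not isinstance(tensor, torch.Tensor):
        tensor = torch.as_tensor(tensor)
    # Assumption: keep the current dtype when none is requested.
    target_dtype = dtype if dtype is not None else tensor.dtype
    # Stay on the current device unless CUDA is requested.
    target_device = torch.device(cuda_device) if cuda else tensor.device
    return tensor.to(device=target_device, dtype=target_dtype)


array = np.random.rand(100, 100, 50)  # float64 by default
t = norm_type_sketch(array, dtype=torch.float32)
print(t.dtype, t.device)  # torch.float32 cpu

# Only attempt the GPU path when a CUDA device is present.
if torch.cuda.is_available():
    t_gpu = norm_type_sketch(array, cuda=True, dtype=torch.float16)
    print(t_gpu.device, t_gpu.dtype)
```

`torch.as_tensor` shares memory with a NumPy array when no conversion is needed, and `Tensor.to` returns the tensor itself when dtype and device already match, so the no-op path avoids copies.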