spacetransformer.torch.validation
torch.validation
Validation utilities for PyTorch-specific operations.
This module provides validation functions for PyTorch tensors, GPU devices, and other parameters specific to GPU-accelerated operations in SpaceTransformer.
Example (validating inputs in a PyTorch function):
>>> from spacetransformer.torch.validation import validate_tensor, validate_device
>>> def process_image(image, device="cuda:0"):
...     image_tensor = validate_tensor(image, expected_dim=5)
...     device = validate_device(device)
...     # Proceed with validated inputs
Functions
| Name | Description |
|---|---|
| validate_device | Validate and return a PyTorch device. |
| validate_image_tensor | Validate medical image tensor for processing. |
| validate_interpolation_mode | Validate interpolation mode for grid_sample and other operations. |
| validate_padding_mode | Validate padding mode for grid_sample and other operations. |
| validate_tensor | Validate and optionally convert input to PyTorch tensor. |
validate_device
torch.validation.validate_device(device)
Validate and return a PyTorch device.
Args:
- device: Device specification (string or torch.device)
Returns:
- torch.device: Validated PyTorch device
Raises:
- ValidationError: If device is invalid
- CudaError: If a CUDA device is specified but CUDA is not available
Example:
>>> validate_device("cpu")
device(type='cpu')
>>> validate_device("cuda:0")  # Raises CudaError if CUDA is unavailable
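The two documented failure modes (an invalid specification versus a valid CUDA request on a machine without CUDA) can be illustrated with a small stand-alone sketch. It does not call spacetransformer or PyTorch: ValidationError and CudaError are local stand-ins for the library's exceptions, parse_device is a hypothetical helper, and its cuda_device_count argument replaces a call such as torch.cuda.device_count() so the logic runs without a GPU. The accepted device types (cpu, cuda) are an assumption; the real function may accept others.

```python
from typing import Optional, Tuple


class ValidationError(ValueError):
    """Local stand-in for spacetransformer's ValidationError (assumed)."""


class CudaError(RuntimeError):
    """Local stand-in for spacetransformer's CudaError (assumed)."""


# Assumption: the real validate_device may accept further device types.
KNOWN_DEVICE_TYPES = {"cpu", "cuda"}


def parse_device(spec: str, cuda_device_count: int) -> Tuple[str, Optional[int]]:
    """Hypothetical sketch: parse "type[:index]" and check CUDA availability.

    cuda_device_count replaces torch.cuda.device_count() so this runs
    without PyTorch or a GPU.
    """
    if not isinstance(spec, str):
        raise ValidationError(f"device must be a string, got {type(spec).__name__}")
    dev_type, sep, index_str = spec.strip().lower().partition(":")
    if dev_type not in KNOWN_DEVICE_TYPES:
        raise ValidationError(f"unknown device type {dev_type!r} in {spec!r}")
    index = None
    if sep:
        # A colon must be followed by a non-negative integer index.
        if not index_str.isdecimal():
            raise ValidationError(f"invalid device index in {spec!r}")
        index = int(index_str)
    if dev_type == "cuda":
        if cuda_device_count == 0:
            raise CudaError("a CUDA device was requested but CUDA is not available")
        if index is not None and index >= cuda_device_count:
            raise CudaError(
                f"cuda:{index} requested but only {cuda_device_count} device(s) found"
            )
    return dev_type, index


print(parse_device("cpu", cuda_device_count=0))  # ('cpu', None)
print(parse_device("cuda:0", cuda_device_count=1))  # ('cuda', 0)
```

With PyTorch installed, the parsing step would normally be delegated to torch.device(spec) and the availability check to torch.cuda.is_available().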
validate_image_tensor
torch.validation.validate_image_tensor(
image,
min_dim=3,
max_dim=5,
dtype=None,
device=None,
name='image',
)
Validate medical image tensor for processing.
Args:
- image: Input image as a tensor, array, or array-like
- min_dim: Minimum allowed number of dimensions (default 3 for 3D medical images)
- max_dim: Maximum allowed number of dimensions (default 5 for batched images)
- dtype: Target data type (None to keep the original)
- device: Target device (None to keep the original)
- name: Parameter name for error messages
Returns: a tuple containing:
- torch.Tensor: Validated image tensor
- int: Original number of dimensions
Raises:
- ValidationError: If image is invalid
Example:
>>> import torch
>>> image = torch.rand(100, 100, 50)
>>> tensor, ndim = validate_image_tensor(image)
>>> print(ndim)
3
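Returning the original number of dimensions suggests the tensor may be reshaped internally. torch.nn.functional.grid_sample expects volumetric input of shape (N, C, D, H, W), so one plausible design, sketched here on plain shape tuples, is to prepend singleton batch and channel axes and let the caller strip them afterwards. check_image_shape and restore_shape are hypothetical names, the promotion to 5-D is an assumption rather than documented behaviour, and ValidationError is a local stand-in.

```python
class ValidationError(ValueError):
    """Local stand-in for spacetransformer's ValidationError (assumed)."""


def check_image_shape(shape, min_dim=3, max_dim=5, name="image"):
    """Hypothetical sketch of the min_dim/max_dim check on a shape tuple.

    Returns the shape promoted to 5-D (N, C, D, H, W) plus the original
    number of dimensions. The promotion step is an assumption.
    """
    shape = tuple(shape)
    ndim = len(shape)
    if not min_dim <= ndim <= max_dim:
        raise ValidationError(
            f"{name} must have {min_dim} to {max_dim} dimensions, got {ndim}"
        )
    # Prepend singleton axes: (D, H, W) -> (1, 1, D, H, W).
    promoted = (1,) * (5 - ndim) + shape
    return promoted, ndim


def restore_shape(promoted, original_ndim):
    """Drop the leading singleton axes that check_image_shape added."""
    return tuple(promoted[len(promoted) - original_ndim:])


promoted, ndim = check_image_shape((100, 100, 50))
print(promoted, ndim)  # (1, 1, 100, 100, 50) 3
print(restore_shape(promoted, ndim))  # (100, 100, 50)
```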
validate_interpolation_mode
torch.validation.validate_interpolation_mode(mode, name='mode')
Validate interpolation mode for grid_sample and other operations.
Args:
- mode: Interpolation mode string
- name: Parameter name for error messages
Returns:
- str: Validated interpolation mode
Raises:
- ValidationError: If mode is invalid
Example:
>>> validate_interpolation_mode("trilinear")
'trilinear'
>>> validate_interpolation_mode("invalid")  # Raises ValidationError
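A minimal sketch of mode validation, assuming the accepted names are nearest, bilinear, trilinear and bicubic and that input is normalised for case and whitespace (neither the list nor the normalisation is documented here). torch.nn.functional.grid_sample itself has no trilinear option; its documentation notes that mode='bilinear' with 5-D input performs trilinear interpolation, so the hypothetical to_grid_sample_mode helper shows one way such a name could be translated before the call. check_interpolation_mode and ValidationError are local stand-ins.

```python
class ValidationError(ValueError):
    """Local stand-in for spacetransformer's ValidationError (assumed)."""


# Assumption: the names the library accepts; its real list may differ.
VALID_INTERPOLATION_MODES = ("nearest", "bilinear", "trilinear", "bicubic")


def check_interpolation_mode(mode, name="mode"):
    """Hypothetical sketch: normalise case/whitespace (assumed), then validate."""
    if not isinstance(mode, str):
        raise ValidationError(f"{name} must be a string, got {type(mode).__name__}")
    normalized = mode.strip().lower()
    if normalized not in VALID_INTERPOLATION_MODES:
        raise ValidationError(
            f"{name} must be one of {VALID_INTERPOLATION_MODES}, got {mode!r}"
        )
    return normalized


def to_grid_sample_mode(mode):
    """Translate a validated name to torch.nn.functional.grid_sample's vocabulary.

    grid_sample has no "trilinear" option; with 5-D input, mode="bilinear"
    interpolates trilinearly.
    """
    return "bilinear" if mode == "trilinear" else mode


mode = check_interpolation_mode("Trilinear")
print(mode, to_grid_sample_mode(mode))  # trilinear bilinear
```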
validate_padding_mode
torch.validation.validate_padding_mode(mode, name='pad_mode')
Validate padding mode for grid_sample and other operations.
Args:
- mode: Padding mode string
- name: Parameter name for error messages
Returns:
- str: Validated padding mode
Raises:
- ValidationError: If mode is invalid
Example:
>>> validate_padding_mode("zeros")
'zeros'
>>> validate_padding_mode("invalid")  # Raises ValidationError
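torch.nn.functional.grid_sample accepts exactly three padding modes: zeros, border and reflection. The sketch below assumes validate_padding_mode enforces that same set, and shows how the name argument can be folded into the error message so the user sees which parameter was wrong. check_padding_mode and the local ValidationError are hypothetical stand-ins.

```python
class ValidationError(ValueError):
    """Local stand-in for spacetransformer's ValidationError (assumed)."""


# The padding modes accepted by torch.nn.functional.grid_sample.
# Whether spacetransformer accepts further aliases is not documented here.
GRID_SAMPLE_PADDING_MODES = frozenset({"zeros", "border", "reflection"})


def check_padding_mode(mode, name="pad_mode"):
    """Hypothetical sketch: reject anything grid_sample would not accept."""
    if not isinstance(mode, str) or mode not in GRID_SAMPLE_PADDING_MODES:
        options = ", ".join(sorted(GRID_SAMPLE_PADDING_MODES))
        raise ValidationError(f"{name}={mode!r} is invalid; expected one of: {options}")
    return mode


print(check_padding_mode("border"))  # border
```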
validate_tensor
torch.validation.validate_tensor(
tensor,
expected_dim=None,
dtype=None,
device=None,
name='tensor',
)
Validate and optionally convert input to PyTorch tensor.
Args:
- tensor: Input tensor or array-like object
- expected_dim: Expected number of dimensions (None to skip the check)
- dtype: Target data type (None to keep the original)
- device: Target device (None to keep the original)
- name: Parameter name for error messages
Returns:
- torch.Tensor: Validated PyTorch tensor
Raises:
- ValidationError: If tensor is invalid
Example:
>>> import numpy as np
>>> import torch
>>> image = np.random.rand(100, 100, 50)
>>> tensor = validate_tensor(image, expected_dim=3)
>>> tensor = validate_tensor(image, dtype=torch.float32, device="cuda:0")  # Needs an available CUDA device
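For nested-list input, the expected_dim check amounts to counting nesting depth, the same rule that gives torch.tensor([]) a single dimension of length 0. The stdlib-only sketch below uses hypothetical helpers infer_ndim and check_expected_dim and a local ValidationError; unlike a real tensor conversion it inspects only the first element at each level, so it assumes rectangular input and does not detect ragged lists.

```python
class ValidationError(ValueError):
    """Local stand-in for spacetransformer's ValidationError (assumed)."""


def infer_ndim(data):
    """Nesting depth of a rectangular nested list/tuple; a scalar is 0-D.

    Only the first element at each level is inspected, so ragged input
    is not detected.
    """
    ndim = 0
    while isinstance(data, (list, tuple)):
        if not data:
            # An empty sequence still contributes one dimension (of length 0).
            return ndim + 1
        data = data[0]
        ndim += 1
    return ndim


def check_expected_dim(data, expected_dim=None, name="tensor"):
    """Hypothetical sketch of validate_tensor's expected_dim check."""
    ndim = infer_ndim(data)
    if expected_dim is not None and ndim != expected_dim:
        raise ValidationError(f"{name} must be {expected_dim}-D, got {ndim}-D")
    return ndim


image = [[[0.0] * 4 for _ in range(3)] for _ in range(2)]  # 2 x 3 x 4 nested list
print(check_expected_dim(image, expected_dim=3))  # 3
```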