spacetransformer.torch.validation

torch.validation

Validation utilities for PyTorch-specific operations.

This module provides validation functions for PyTorch tensors, GPU devices, and other parameters specific to GPU-accelerated operations in SpaceTransformer.

Example: Validating tensors in a PyTorch function:

>>> from spacetransformer.torch.validation import validate_tensor, validate_device
>>> def process_image(image, device="cuda:0"):
...     image_tensor = validate_tensor(image, expected_dim=5)
...     device = validate_device(device)
...     # Proceed with validated inputs

Functions

Name                          Description
validate_device               Validate and return a PyTorch device.
validate_image_tensor         Validate medical image tensor for processing.
validate_interpolation_mode   Validate interpolation mode for grid_sample and other operations.
validate_padding_mode         Validate padding mode for grid_sample and other operations.
validate_tensor               Validate and optionally convert input to PyTorch tensor.

validate_device

torch.validation.validate_device(device)

Validate and return a PyTorch device.

Args:
    device: Device specification (string or torch.device).

Returns:
    torch.device: Validated PyTorch device.

Raises:
    ValidationError: If the device specification is invalid.
    CudaError: If a CUDA device is specified but CUDA is not available.

Example:
    >>> validate_device("cpu")
    device(type='cpu')
    >>> validate_device("cuda:0")  # Raises CudaError if CUDA is unavailable
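The sketch below illustrates the kind of checks a device validator like this might perform. It is not the library's implementation: `parse_device_spec` is a hypothetical helper, it accepts only `"cpu"`, `"cuda"`, and `"cuda:<index>"` (PyTorch itself recognizes more device types), and `cuda_device_count` is passed in as a stand-in for `torch.cuda.device_count()` so the example runs without PyTorch. It raises plain `ValueError`/`RuntimeError` rather than the `ValidationError`/`CudaError` classes named above.

```python
import re
from typing import Optional, Tuple

# Hypothetical rule: only "cpu", "cuda", and "cuda:<index>" are accepted here.
_DEVICE_PATTERN = re.compile(r"(cpu|cuda)(?::(\d+))?")


def parse_device_spec(spec: str, cuda_device_count: int) -> Tuple[str, Optional[int]]:
    """Split a device string into (type, index) and check CUDA availability.

    cuda_device_count stands in for torch.cuda.device_count() so this sketch
    runs without PyTorch installed.
    """
    match = _DEVICE_PATTERN.fullmatch(spec)
    if match is None:
        raise ValueError(f"Invalid device specification: {spec!r}")
    dev_type = match.group(1)
    index = int(match.group(2)) if match.group(2) is not None else None

    if dev_type == "cuda":
        # A CUDA request is only valid if at least one device is visible...
        if cuda_device_count == 0:
            raise RuntimeError(f"{spec!r} requested but CUDA is not available")
        # ...and an explicit index must refer to one of the visible devices.
        if index is not None and index >= cuda_device_count:
            raise RuntimeError(
                f"CUDA index {index} out of range ({cuda_device_count} device(s) visible)"
            )
    return dev_type, index


print(parse_device_spec("cpu", cuda_device_count=0))  # ('cpu', None)
```

In real code one would pass `torch.cuda.device_count()` and, once the spec passes, construct `torch.device(spec)`.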

validate_image_tensor

torch.validation.validate_image_tensor(
    image,
    min_dim=3,
    max_dim=5,
    dtype=None,
    device=None,
    name='image',
)

Validate medical image tensor for processing.

Args:
    image: Input image as a tensor, array, or array-like object.
    min_dim: Minimum allowed number of dimensions (default 3, for 3D medical images).
    max_dim: Maximum allowed number of dimensions (default 5, for batched images).
    dtype: Target data type (None keeps the original).
    device: Target device (None keeps the original).
    name: Parameter name used in error messages.

Returns:
    Tuple containing:
        - torch.Tensor: Validated image tensor.
        - int: Original number of dimensions.

Raises:
    ValidationError: If the image is invalid.

Example:
    >>> import torch
    >>> image = torch.rand(100, 100, 50)
    >>> tensor, ndim = validate_image_tensor(image)
    >>> print(ndim)
    3
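One plausible reason for returning the original number of dimensions is that volumetric resampling with `torch.nn.functional.grid_sample` expects 5-D input of shape (N, C, D, H, W), so a caller may promote a 3-D or 4-D image and later strip the added axes. The following sketch shows that bookkeeping on plain shape tuples. `check_and_promote` and `restore_shape` are hypothetical helpers, reading 4-D input as (C, D, H, W) is an assumption, and nothing here is taken from SpaceTransformer's source.

```python
from typing import Tuple


def check_and_promote(
    shape: Tuple[int, ...], min_dim: int = 3, max_dim: int = 5, name: str = "image"
) -> Tuple[Tuple[int, ...], int]:
    """Check the dimension count, then pad the shape to 5-D (N, C, D, H, W).

    Assumes max_dim <= 5, 3-D input is (D, H, W), and 4-D input is (C, D, H, W).
    """
    ndim = len(shape)
    if not min_dim <= ndim <= max_dim:
        raise ValueError(f"{name} must have {min_dim} to {max_dim} dimensions, got {ndim}")
    # Prepend singleton axes, e.g. (D, H, W) -> (1, 1, D, H, W).
    promoted = (1,) * (5 - ndim) + tuple(shape)
    return promoted, ndim


def restore_shape(output_shape: Tuple[int, ...], original_ndim: int) -> Tuple[int, ...]:
    """Drop the leading axes that check_and_promote added."""
    return tuple(output_shape[5 - original_ndim:])


promoted, ndim = check_and_promote((100, 100, 50))
print(promoted, ndim)                           # (1, 1, 100, 100, 50) 3
print(restore_shape((1, 1, 64, 64, 32), ndim))  # (64, 64, 32)
```

Keeping the original `ndim` alongside the tensor means the restore step works even when resampling changes the spatial size.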

validate_interpolation_mode

torch.validation.validate_interpolation_mode(mode, name='mode')

Validate interpolation mode for grid_sample and other operations.

Args:
    mode: Interpolation mode string.
    name: Parameter name used in error messages.

Returns:
    str: Validated interpolation mode.

Raises:
    ValidationError: If the mode is invalid.

Example:
    >>> validate_interpolation_mode("trilinear")
    'trilinear'
    >>> validate_interpolation_mode("invalid")  # Raises ValidationError
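PyTorch's `grid_sample` has no `"trilinear"` option: with 5-D input, `mode="bilinear"` already performs trilinear interpolation. A validator that accepts `"trilinear"` might therefore translate it before calling `grid_sample`. The sketch below shows one way to do that; the accepted mode set, `check_interpolation_mode`, and `to_grid_sample_mode` are assumptions for illustration, not SpaceTransformer's API.

```python
# Assumed accepted set for illustration; SpaceTransformer's list may differ.
_INTERP_MODES = ("nearest", "bilinear", "trilinear")


def check_interpolation_mode(mode: str, name: str = "mode") -> str:
    """Return mode unchanged if it is an accepted string, else raise."""
    if not isinstance(mode, str):
        raise TypeError(f"{name} must be a string, got {type(mode).__name__}")
    if mode not in _INTERP_MODES:
        raise ValueError(f"{name} must be one of {list(_INTERP_MODES)}, got {mode!r}")
    return mode


def to_grid_sample_mode(mode: str) -> str:
    """Map a validated mode onto grid_sample's vocabulary.

    grid_sample has no "trilinear" option; "bilinear" on 5-D input is trilinear.
    """
    return "bilinear" if mode == "trilinear" else mode


print(to_grid_sample_mode(check_interpolation_mode("trilinear")))  # bilinear
```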

validate_padding_mode

torch.validation.validate_padding_mode(mode, name='pad_mode')

Validate padding mode for grid_sample and other operations.

Args:
    mode: Padding mode string.
    name: Parameter name used in error messages.

Returns:
    str: Validated padding mode.

Raises:
    ValidationError: If the mode is invalid.

Example:
    >>> validate_padding_mode("zeros")
    'zeros'
    >>> validate_padding_mode("invalid")  # Raises ValidationError
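`grid_sample` accepts `"zeros"`, `"border"`, and `"reflection"` as `padding_mode`. A sketch of a validator over that set follows, with one optional design choice: suggesting the closest valid name via `difflib.get_close_matches` when the input looks like a typo. The function name `check_padding_mode` and the exact accepted set are assumptions; the library's own list of padding modes may differ.

```python
import difflib

# grid_sample's padding_mode options; the library may accept a different set.
_PADDING_MODES = ("zeros", "border", "reflection")


def check_padding_mode(mode: str, name: str = "pad_mode") -> str:
    """Return mode if valid; otherwise raise with a 'did you mean' hint."""
    if mode in _PADDING_MODES:
        return mode
    # Suggest the closest valid name (difflib's default similarity cutoff is 0.6).
    close = difflib.get_close_matches(str(mode), _PADDING_MODES, n=1)
    hint = f" Did you mean {close[0]!r}?" if close else ""
    raise ValueError(f"{name} must be one of {list(_PADDING_MODES)}, got {mode!r}.{hint}")


try:
    check_padding_mode("zero")
except ValueError as err:
    # pad_mode must be one of ['zeros', 'border', 'reflection'], got 'zero'. Did you mean 'zeros'?
    print(err)
```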

validate_tensor

torch.validation.validate_tensor(
    tensor,
    expected_dim=None,
    dtype=None,
    device=None,
    name='tensor',
)

Validate and optionally convert input to PyTorch tensor.

Args:
    tensor: Input tensor or array-like object.
    expected_dim: Expected number of dimensions (None skips the check).
    dtype: Target data type (None keeps the original).
    device: Target device (None keeps the original).
    name: Parameter name used in error messages.

Returns:
    torch.Tensor: Validated PyTorch tensor.

Raises:
    ValidationError: If the tensor is invalid.

Example:
    >>> import numpy as np
    >>> import torch
    >>> image = np.random.rand(100, 100, 50)
    >>> tensor = validate_tensor(image, expected_dim=3)
    >>> tensor = validate_tensor(image, dtype=torch.float32, device="cuda:0")  # Requires a CUDA device
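Because `validate_tensor` accepts array-like input, a dimension check has to cope with nested Python sequences, and ragged nesting (sub-lists of different lengths) cannot form a tensor at all. The pure-Python sketch below infers a nested sequence's shape, rejects ragged input, and applies an `expected_dim` check. `infer_shape` and `check_expected_dim` are hypothetical helpers; in practice one would typically convert with `torch.as_tensor` and then compare `tensor.dim()` against `expected_dim`.

```python
from typing import Any, Optional, Tuple


def infer_shape(obj: Any) -> Tuple[int, ...]:
    """Return the shape of a nested list/tuple, rejecting ragged nesting."""
    if not isinstance(obj, (list, tuple)):
        return ()  # a scalar leaf contributes no dimensions
    if len(obj) == 0:
        return (0,)
    # Every child must have the same shape for the whole to be rectangular.
    child_shapes = {infer_shape(item) for item in obj}
    if len(child_shapes) != 1:
        raise ValueError("ragged nested sequence: sub-sequences differ in shape")
    return (len(obj),) + child_shapes.pop()


def check_expected_dim(
    obj: Any, expected_dim: Optional[int] = None, name: str = "tensor"
) -> Tuple[int, ...]:
    """Infer the shape and, if expected_dim is given, enforce the dimension count."""
    shape = infer_shape(obj)
    if expected_dim is not None and len(shape) != expected_dim:
        raise ValueError(
            f"{name} must have {expected_dim} dimensions, got {len(shape)} (shape {shape})"
        )
    return shape


print(check_expected_dim([[1, 2, 3], [4, 5, 6]], expected_dim=2))  # (2, 3)
```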