spacetransformer.torch.gpu_utils
GPU error handling utilities for SpaceTransformer PyTorch operations.
This module provides utilities for handling CUDA and GPU-related errors in PyTorch operations, converting low-level GPU errors into clear, actionable error messages for users.
Example: Handling CUDA errors in image processing:
>>> try:
...     # Some GPU operation that might fail
...     result = torch.cuda.operation()
... except RuntimeError as e:
...     handle_cuda_error(e, "image warping")
CudaError: GPU out of memory during image warping. Try reducing batch size,
using smaller images, or switching to CPU processing.
Functions
| Name | Description |
|---|---|
| check_cuda_availability | Check if CUDA is available and provide helpful information if not. |
| get_gpu_memory_info | Get current GPU memory usage information. |
| handle_cuda_error | Convert CUDA errors to clear, informative error messages. |
| validate_tensor_device | Validate that a tensor is on the expected device. |
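Taken together, these utilities support a defensive pattern for GPU work: check availability, inspect free memory, run the operation, and translate any failure into a `CudaError`. A minimal sketch of that pattern, assuming the functions are importable from `spacetransformer.torch.gpu_utils` as the module path suggests (`process_batch` and its placeholder computation are hypothetical, not part of the library):

```python
import torch
from spacetransformer.torch.gpu_utils import (
    check_cuda_availability,
    get_gpu_memory_info,
    handle_cuda_error,
)

def process_batch(batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical wrapper showing the defensive GPU pattern."""
    device = "cuda:0" if check_cuda_availability() else "cpu"
    if device != "cpu":
        info = get_gpu_memory_info()
        if info and info["free"] < 2.0:  # less than 2 GB free
            print("Warning: low GPU memory; consider smaller batches")
    try:
        return (batch.to(device) * 2).cpu()  # placeholder computation
    except RuntimeError as e:
        handle_cuda_error(e, "batch processing")  # always raises CudaError
```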
check_cuda_availability
torch.gpu_utils.check_cuda_availability()

Check if CUDA is available and provide helpful information if not.
This function checks CUDA availability and provides detailed information about the GPU setup, helping users understand their hardware configuration.
Returns: bool: True if CUDA is available and working, False otherwise
Example: Checking CUDA before GPU operations:
>>> if check_cuda_availability():
...     device = "cuda:0"
... else:
...     device = "cpu"
...     print("Using CPU processing")
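The implementation is not reproduced here, but a minimal sketch of such a check, using only standard `torch.cuda` calls, might look like this (the printed details are illustrative):

```python
import torch

def check_cuda_availability() -> bool:
    """Sketch: return CUDA availability, printing setup details."""
    if not torch.cuda.is_available():
        print("CUDA is not available; GPU operations will fall back to CPU.")
        return False
    # Summarize each visible GPU so users can verify their hardware setup.
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(f"GPU {i}: {props.name}, {props.total_memory / 1024**3:.1f} GB")
    return True
```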
get_gpu_memory_info
torch.gpu_utils.get_gpu_memory_info()

Get current GPU memory usage information.
This function provides detailed information about GPU memory usage, which is helpful for debugging memory-related issues.
Returns: dict: Dictionary containing memory information in GB, or an empty dict if CUDA is unavailable
Example: Checking memory before large operations:
>>> memory_info = get_gpu_memory_info()
>>> if memory_info and memory_info['free'] < 2.0:  # Less than 2GB free
...     print("Warning: Low GPU memory available")
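A plausible sketch of this function built on `torch.cuda.mem_get_info` (available in recent PyTorch releases). Only the `'free'` key is confirmed by the example above; the other keys are assumptions:

```python
import torch

def get_gpu_memory_info() -> dict:
    """Sketch: GPU memory figures in GB, or {} without CUDA."""
    if not torch.cuda.is_available():
        return {}
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    gb = 1024 ** 3
    return {
        "free": free_bytes / gb,                          # confirmed key
        "total": total_bytes / gb,                        # assumed key
        "allocated": torch.cuda.memory_allocated() / gb,  # assumed key
        "reserved": torch.cuda.memory_reserved() / gb,    # assumed key
    }
```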
handle_cuda_error
torch.gpu_utils.handle_cuda_error(error, operation)

Convert CUDA errors to clear, informative error messages.
This function analyzes CUDA runtime errors and converts them into user-friendly error messages with specific suggestions for resolution. It handles common GPU issues encountered in medical image processing.
Args:
- error: The original CUDA or PyTorch error
- operation: Description of the operation that failed (e.g., "image warping")
Raises: CudaError: Always raised, with a clear error message and specific suggestions for resolution
Example: Converting out-of-memory errors:
>>> try:
...     # Allocating directly on the GPU ensures a CUDA (not host) OOM
...     large_tensor = torch.zeros(10000, 10000, 10000, device="cuda")
... except RuntimeError as e:
...     handle_cuda_error(e, "tensor allocation")
CudaError: GPU out of memory during tensor allocation. Try reducing batch size,
using smaller images, or switching to CPU processing.
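Internally this is a message-dispatch pattern: inspect the error text, then re-raise as `CudaError` with targeted advice. A hedged sketch of that pattern (the local `CudaError` stand-in and the matched substrings are assumptions; the real module defines its own exception):

```python
class CudaError(RuntimeError):
    """Stand-in for the module's real CudaError exception."""

def handle_cuda_error(error: Exception, operation: str) -> None:
    """Sketch: always re-raise as CudaError with actionable advice."""
    message = str(error).lower()
    if "out of memory" in message:
        raise CudaError(
            f"GPU out of memory during {operation}. Try reducing batch size, "
            "using smaller images, or switching to CPU processing."
        ) from error
    if "cuda" in message and ("driver" in message or "not available" in message):
        raise CudaError(
            f"CUDA is unavailable for {operation}. Check that a compatible "
            "driver and a CUDA-enabled PyTorch build are installed."
        ) from error
    # Generic fallback so callers always receive a CudaError.
    raise CudaError(f"GPU error during {operation}: {error}") from error
```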
validate_tensor_device
torch.gpu_utils.validate_tensor_device(tensor, expected_device)

Validate that a tensor is on the expected device.
This function checks tensor device placement and raises a clear error if the tensor is on the wrong device.
Args:
- tensor: PyTorch tensor to check
- expected_device: Expected device string (e.g., "cuda:0", "cpu")
Raises: CudaError: If tensor is on wrong device
Example: Validating tensor placement:
>>> tensor = torch.rand(100, 100).cuda()
>>> validate_tensor_device(tensor, "cuda:0") # No error
>>> validate_tensor_device(tensor, "cpu") # Raises CudaError
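A minimal sketch of this check using `torch.device` for normalization (the `CudaError` stand-in is an assumption; note that an unindexed expectation like "cuda" matches any CUDA tensor here):

```python
import torch

class CudaError(RuntimeError):
    """Stand-in for the module's real CudaError exception."""

def validate_tensor_device(tensor: torch.Tensor, expected_device: str) -> None:
    """Sketch: raise CudaError if tensor sits on a different device."""
    expected = torch.device(expected_device)
    actual = tensor.device
    same_type = actual.type == expected.type
    # An expectation without an index (e.g., "cuda") matches any index.
    same_index = expected.index is None or actual.index == expected.index
    if not (same_type and same_index):
        raise CudaError(
            f"Tensor is on {actual} but {expected_device} was expected. "
            f"Move it with tensor.to('{expected_device}') first."
        )
```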