import logging
import math
from typing import Tuple
import torch
[docs]
logger = logging.getLogger(__name__)
[docs]
def get_device():
"""
Utility function to get the currently available PyTorch device.
Returns
-------
torch.device
Available torch device (either cuda or cpu).
"""
try:
# Try to get CUDA device count to check if CUDA is properly initialized
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
return torch.device("cuda:0")
else:
return torch.device("cpu")
except Exception as e:
logger.warning(f"CUDA error: {e}. Falling back to CPU.")
return torch.device("cpu")
[docs]
def gradient2d(tensor):
"""
Compute gradients (dy, dx) of a 2D tensor using centered differences.
"""
dx = (
torch.roll(tensor, shifts=-1, dims=-1) - torch.roll(tensor, shifts=1, dims=-1)
) / 2.0
dy = (
torch.roll(tensor, shifts=-1, dims=-2) - torch.roll(tensor, shifts=1, dims=-2)
) / 2.0
# fix edges with forward/backward differences
dx[..., 0] = tensor[..., 1] - tensor[..., 0]
dx[..., -1] = tensor[..., -1] - tensor[..., -2]
dy[..., 0, :] = tensor[..., 1, :] - tensor[..., 0, :]
dy[..., -1, :] = tensor[..., -1, :] - tensor[..., -2, :]
return dy, dx
[docs]
def estimate_memory_per_pixel(reproject_instance, interpolation_mode: str) -> float:
"""
Estimate memory usage per output pixel in bytes.
Parameters
----------
reproject_instance : Reproject
An initialized Reproject instance.
interpolation_mode : str
Interpolation mode being used.
Returns
-------
float
Estimated bytes per pixel.
"""
B = len(reproject_instance.batch_source_images)
# Base memory for coordinates and intermediate tensors
# Each pixel needs:
# - x_source, y_source: 2 * 8 bytes (float64)
# - x_normalized, y_normalized: 2 * 8 bytes
# - combined_result: 2 channels * 8 bytes
# - result: 8 bytes
# - valid_pixels mask: 1 byte
base_per_pixel = (2 + 2 + 2 + 1 + 1) * 8
# Add memory for Jacobian if needed
if reproject_instance.compute_jacobian and reproject_instance.conserve_flux:
# 4 gradient tensors: dx_x, dy_x, dx_y, dy_y
base_per_pixel += 4 * 8
# Multiply by batch size
total_per_pixel = base_per_pixel * B
# Add overhead for intermediate operations (20% extra)
total_per_pixel *= 1.2
return total_per_pixel
[docs]
def calculate_chunk_size(
reproject_instance,
output_shape: Tuple[int, int],
max_memory_mb: float,
safety_factor: float,
interpolation_mode: str = "bilinear",
) -> Tuple[int, int]:
"""
Calculate optimal chunk size based on memory constraints.
Parameters
----------
reproject_instance : Reproject
An initialized Reproject instance.
output_shape : Tuple[int, int]
Shape of the output image (H, W).
max_memory_mb : float
Maximum memory to use in megabytes.
safety_factor : float
Safety factor (0-1) for memory calculation.
interpolation_mode : str
Interpolation mode to use.
Returns
-------
Tuple[int, int]
Chunk size (chunk_height, chunk_width).
"""
H, W = output_shape
# Available memory in bytes
available_bytes = max_memory_mb * 1024 * 1024 * safety_factor
# Memory per pixel
bytes_per_pixel = estimate_memory_per_pixel(reproject_instance, interpolation_mode)
# Maximum pixels per chunk
max_pixels = int(available_bytes / bytes_per_pixel)
# Ensure at least 1 row can be processed
if max_pixels < W:
logger.warning(
f"Memory limit very tight! Estimated {bytes_per_pixel:.2f} bytes/pixel. "
f"Consider increasing max_memory_mb (currently {max_memory_mb} MB)."
)
max_pixels = W
# Calculate chunk dimensions
# Try to make chunks roughly square for better cache performance
chunk_height = min(H, int(math.sqrt(max_pixels * H / W)))
chunk_height = max(1, chunk_height) # At least 1 row
# Calculate corresponding width
chunk_width = min(W, max_pixels // chunk_height)
chunk_width = max(1, chunk_width)
return chunk_height, chunk_width
[docs]
def process_chunk(
reproject_instance,
y_start: int,
y_end: int,
x_start: int,
x_end: int,
interpolation_mode: str = "bilinear",
) -> torch.Tensor:
"""
Process a single chunk of the reprojection.
Parameters
----------
reproject_instance : Reproject
An initialized Reproject instance.
y_start, y_end : int
Y-axis range for the chunk.
x_start, x_end : int
X-axis range for the chunk.
interpolation_mode : str
Interpolation mode.
Returns
-------
torch.Tensor
Reprojected chunk.
"""
# Create chunk-specific grid
B = len(reproject_instance.batch_source_images)
chunk_h = y_end - y_start
chunk_w = x_end - x_start
y_chunk = (
torch.arange(
y_start, y_end, dtype=torch.float64, device=reproject_instance.device
)
.view(1, -1, 1)
.expand(B, chunk_h, chunk_w)
.clone() # Force a copy to avoid hidden reference issues
)
x_chunk = (
torch.arange(
x_start, x_end, dtype=torch.float64, device=reproject_instance.device
)
.view(1, 1, -1)
.expand(B, chunk_h, chunk_w)
.clone() # Force a copy to avoid hidden reference issues
)
# Temporarily override the target grid
original_grid = reproject_instance.target_grid
reproject_instance.target_grid = (y_chunk, x_chunk)
try:
# Process this chunk
chunk_result = reproject_instance.interpolate_source_image(
interpolation_mode=interpolation_mode
)
finally:
# Restore original grid
reproject_instance.target_grid = original_grid
# Explicitly delete chunk tensors to free memory
del y_chunk, x_chunk
# Clear cache for both CPU and CUDA
if reproject_instance.device.type == "cuda":
torch.cuda.empty_cache()
return chunk_result
[docs]
def reproject_chunked(
reproject_instance,
max_memory_mb: float,
safety_factor: float,
interpolation_mode: str = "bilinear",
show_progress: bool = True,
) -> torch.Tensor:
"""
Perform chunked reprojection.
Parameters
----------
reproject_instance : Reproject
An initialized Reproject instance.
max_memory_mb : float
Maximum memory to use in megabytes.
safety_factor : float
Safety factor for memory calculation.
interpolation_mode : str
Interpolation mode to use.
show_progress : bool
Whether to log progress information.
Returns
-------
torch.Tensor
Full reprojected image.
"""
y_grid, x_grid = reproject_instance.target_grid
B, H, W = y_grid.shape
# Calculate chunk size
chunk_h, chunk_w = calculate_chunk_size(
reproject_instance, (H, W), max_memory_mb, safety_factor, interpolation_mode
)
# Calculate number of chunks
n_chunks_y = math.ceil(H / chunk_h)
n_chunks_x = math.ceil(W / chunk_w)
total_chunks = n_chunks_y * n_chunks_x
# Initialize output
result = torch.full(
(B, H, W), torch.nan, dtype=torch.float64, device=reproject_instance.device
)
# Process chunks
chunk_idx = 0
for i in range(n_chunks_y):
y_start = i * chunk_h
y_end = min((i + 1) * chunk_h, H)
for j in range(n_chunks_x):
x_start = j * chunk_w
x_end = min((j + 1) * chunk_w, W)
chunk_idx += 1
if show_progress:
logger.info(f"Processing chunk {chunk_idx}/{total_chunks}")
# Process chunk
chunk_result = process_chunk(
reproject_instance, y_start, y_end, x_start, x_end, interpolation_mode
)
# Insert into result
result[:, y_start:y_end, x_start:x_end] = chunk_result
# Explicitly delete chunk result to free memory immediately
del chunk_result
# Periodic cache clearing every 10 chunks to prevent accumulation
if chunk_idx % 10 == 0:
if reproject_instance.device.type == "cuda":
torch.cuda.empty_cache()
# Final cache clear
if reproject_instance.device.type == "cuda":
torch.cuda.empty_cache()
return result