import torch
import numpy as np
def lanczos_kernel(x, a=3):
    """
    Lanczos kernel function with parameter a (number of lobes).

    Computes L(x) = sinc(x) * sinc(x / a) for |x| < a and 0 otherwise, where
    sinc is the normalized sinc, sin(pi*x) / (pi*x), with sinc(0) = 1.

    Args:
        x: Distance from center (tensor of any shape).
        a: Number of lobes / kernel support (typically 3 for Lanczos-3). Must be positive.
    Returns:
        Kernel weights, same shape as x (floating point; integer inputs are
        promoted to the default float dtype).
    Raises:
        ValueError: If a is not positive.
    """
    if a <= 0:
        raise ValueError(f"Lanczos parameter a must be positive, got {a!r}")
    x = torch.abs(x)
    # torch.sinc is the normalized sinc and defines sinc(0) = 1 exactly, so no
    # epsilon substitution or special-casing of x == 0 is required.
    weights = torch.sinc(x) * torch.sinc(x / a)
    # Compact support: the kernel vanishes outside |x| < a.
    return torch.where(x < a, weights, torch.zeros_like(weights))
def _process_chunk_vectorized(img, target_x, target_y, radius, dtype, device):
    """
    Vectorized Lanczos interpolation of one 2-D image at a tile of target coordinates.

    For every output location, sums the Lanczos-weighted in-bounds source pixels
    at integer offsets -radius..radius around floor(target), then divides by the
    sum of those weights wherever that sum is positive (locations with a
    non-positive weight sum are left as the raw weighted sum).

    Args:
        img: Source image tensor (H, W).
        target_x: Target x (column) pixel coordinates, shape (chunk_h, chunk_w).
        target_y: Target y (row) pixel coordinates, same shape as target_x.
        radius: Lanczos support; also passed as the kernel parameter ``a``.
        dtype: Data type of the accumulators / returned tensor.
        device: Device of the accumulators / returned tensor.
    Returns:
        Interpolated chunk (chunk_h, chunk_w).
    """
    H, W = img.shape
    result = torch.zeros_like(target_x, dtype=dtype, device=device)
    total_weights = torch.zeros_like(target_x, dtype=dtype, device=device)
    if H == 0 or W == 0:
        # No source pixels at all: every output stays zero.
        return result

    # Anchor taps at floor(target). For a Lanczos-a kernel the non-zero taps are
    # floor-a+1 .. floor+a, which the offsets -radius..radius cover. The anchor is
    # the same for every offset, so compute it once.
    base_y = torch.floor(target_y).long()
    base_x = torch.floor(target_x).long()
    offsets = range(-radius, radius + 1)

    # The 2-D Lanczos kernel is separable (w = w_x * w_y), so the column weights and
    # bounds masks are evaluated once and reused for every row offset.
    col_taps = []
    for dx in offsets:
        src_x = base_x + dx
        in_x = (src_x >= 0) & (src_x < W)
        weight_x = lanczos_kernel(src_x.float() - target_x, a=radius)
        col_taps.append((dx, in_x, weight_x))

    for dy in offsets:
        src_y = base_y + dy
        in_y = (src_y >= 0) & (src_y < H)
        weight_y = lanczos_kernel(src_y.float() - target_y, a=radius)
        # Clamp so indexing is always legal; out-of-bounds taps are zeroed below.
        idx_y = src_y.clamp(0, H - 1)
        for dx, in_x, weight_x in col_taps:
            mask = in_y & in_x
            weight = weight_x * weight_y * mask.float()
            values = img[idx_y, (base_x + dx).clamp(0, W - 1)]
            # torch.where rather than a multiply, so a NaN/inf in a clamped border
            # pixel cannot leak into taps that actually lie outside the image.
            values = torch.where(mask, values, torch.zeros_like(values))
            result += weight * values
            total_weights += weight

    # Normalize only where the weight sum is positive. The denominator is made safe
    # first so the unselected branch never divides by zero (keeps gradients finite).
    positive = total_weights > 0
    safe_total = torch.where(positive, total_weights, torch.ones_like(total_weights))
    return torch.where(positive, result / safe_total, result)
def lanczos_grid_sample(source_image, grid, padding_mode="zeros", chunk_size=1024,
                        align_corners=True):
    """
    Memory-efficient vectorized Lanczos-3 interpolation for large astronomical images.

    Samples ``source_image`` at the normalized locations given by ``grid``, in the
    manner of torch.nn.functional.grid_sample, processing each (batch, channel)
    plane in square output tiles of at most ``chunk_size`` x ``chunk_size``.

    Args:
        source_image: Input tensor of shape (N, C, H, W).
        grid: Grid tensor of shape (N, H_out, W_out, 2) with values nominally in
            [-1, 1]; grid[..., 0] is the x (width) coordinate and grid[..., 1] the
            y (height) coordinate.
        padding_mode: Only "zeros" is supported; any other value raises ValueError.
        chunk_size: Maximum tile height/width processed at once (memory
            management). Must be a positive integer.
        align_corners: If True (the default, and the historical behavior), -1 and
            1 map to the centers of the corner pixels. If False they map to the
            outer edges of the corner pixels, matching grid_sample's
            align_corners=False convention.
    Returns:
        Interpolated tensor of shape (N, C, H_out, W_out) with source_image's
        dtype and device.
    Raises:
        ValueError: On an unsupported padding_mode, a non-positive chunk_size, or
            a grid whose shape is not (N, H_out, W_out, 2).
    """
    if padding_mode != "zeros":
        raise ValueError(
            f'padding_mode={padding_mode!r} is not supported; only "zeros" is implemented'
        )
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    N, C, H, W = source_image.shape
    if grid.dim() != 4 or grid.shape[-1] != 2:
        raise ValueError(f"grid must have shape (N, H_out, W_out, 2), got {tuple(grid.shape)}")
    if grid.shape[0] != N:
        raise ValueError(
            f"grid batch size {grid.shape[0]} does not match source_image batch size {N}"
        )
    _, H_out, W_out, _ = grid.shape
    device = source_image.device
    dtype = source_image.dtype

    # Convert normalized [-1, 1] coordinates to pixel coordinates.
    if align_corners:
        grid_x = ((grid[..., 0] + 1) / 2) * (W - 1)  # (N, H_out, W_out)
        grid_y = ((grid[..., 1] + 1) / 2) * (H - 1)  # (N, H_out, W_out)
    else:
        grid_x = ((grid[..., 0] + 1) * W - 1) / 2
        grid_y = ((grid[..., 1] + 1) * H - 1) / 2

    output = torch.zeros(N, C, H_out, W_out, dtype=dtype, device=device)
    radius = 3  # Lanczos-3

    for n in range(N):
        for c in range(C):
            img = source_image[n, c]  # (H, W)
            # Tile the output so per-tap temporaries stay bounded in size.
            for i_start in range(0, H_out, chunk_size):
                i_end = min(i_start + chunk_size, H_out)
                for j_start in range(0, W_out, chunk_size):
                    j_end = min(j_start + chunk_size, W_out)
                    chunk_x = grid_x[n, i_start:i_end, j_start:j_end]
                    chunk_y = grid_y[n, i_start:i_end, j_start:j_end]
                    output[n, c, i_start:i_end, j_start:j_end] = _process_chunk_vectorized(
                        img, chunk_x, chunk_y, radius, dtype, device
                    )
    return output