Source code for dfreproject.reproject

import logging
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from astropy.io.fits import Header, PrimaryHDU
from astropy.wcs import WCS

from .lanczos import lanczos_grid_sample
from .sip import apply_inverse_sip_distortion, apply_sip_distortion, get_sip_coeffs
from .tensorhdu import TensorHDU
from .utils import get_device, gradient2d, reproject_chunked

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Threshold on the interpolated footprint: target pixels whose footprint is
# not above this value are treated as falling outside the source image.
EPSILON = 1e-10

# Interpolation orders accepted by validate_interpolation_order().
VALID_ORDERS = [
    "bicubic",
    "bilinear",
    "nearest",
    "nearest-neighbors",
    "lanczos",
]
def validate_interpolation_order(order: str) -> str:
    """
    Check a requested interpolation order and return its canonical name.

    Accepted values are the entries of ``VALID_ORDERS``: "bicubic",
    "bilinear", "nearest", "nearest-neighbors" and "lanczos".
    "nearest-neighbors" is an alias and is returned as "nearest"; every other
    accepted value is returned unchanged.

    Parameters
    ----------
    order : str
        Interpolation order to validate.

    Returns
    -------
    str
        Validated (canonical) interpolation order.

    Raises
    ------
    ValueError
        When the provided order is not one of the valid interpolation orders.
    """
    # Guard clause: reject anything that is not an exact member of the list.
    if order not in VALID_ORDERS:
        raise ValueError(f"order must be one of: {', '.join(VALID_ORDERS)}")
    return "nearest" if order == "nearest-neighbors" else order
# Helper functions for trigonometric calculations
def atan2d(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Four-quadrant arctangent in degrees (PyTorch analogue of WCSLib's atan2d).

    Parameters
    ----------
    y : torch.Tensor
        y coordinate(s).
    x : torch.Tensor
        x coordinate(s).

    Returns
    -------
    torch.Tensor
        ``atan2(y, x)`` expressed in degrees.
    """
    angle_rad = torch.atan2(y, x)
    return torch.rad2deg(angle_rad)
def sincosd(angle_deg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    PyTorch implementation of WCSLib's sincosd function.

    Parameters
    ----------
    angle_deg : torch.Tensor
        Angle(s) in degrees.

    Returns
    -------
    tuple(torch.Tensor, torch.Tensor)
        ``(sin(angle), cos(angle))`` for the angle given in degrees. Both
        tensors have the same shape as ``angle_deg``.
    """
    # Convert once and reuse the radians for both trigonometric calls.
    angle_rad = torch.deg2rad(angle_deg)
    return torch.sin(angle_rad), torch.cos(angle_rad)
def interpolate_image(
    source_image: torch.Tensor, grid: torch.Tensor, interpolation_mode: str
) -> torch.Tensor:
    """
    Sample ``source_image`` at the locations given by ``grid``.

    Parameters
    ----------
    source_image : torch.Tensor
        Source image to interpolate.
    grid : torch.Tensor
        Grid on which to interpolate.
    interpolation_mode : str
        Interpolation mode to use. ``'lanczos'`` is routed to
        ``lanczos_grid_sample``; any other value is passed as ``mode`` to
        ``torch.nn.functional.grid_sample`` (e.g. 'bilinear', 'bicubic',
        'nearest').

    Returns
    -------
    torch.Tensor
        Interpolated image. Samples outside the source are zero-padded in
        both code paths (``padding_mode="zeros"``).
    """
    if interpolation_mode != "lanczos":
        # Built-in PyTorch sampler; align_corners=True matches the
        # [-1, 1] <-> [0, size - 1] normalization used by the caller.
        return torch.nn.functional.grid_sample(
            source_image,
            grid,
            mode=interpolation_mode,
            padding_mode="zeros",
            align_corners=True,
        )
    return lanczos_grid_sample(source_image, grid, padding_mode="zeros")
class Reproject:
    # Degrees per radian (180/pi). Kept as a Python float so it is
    # materialized at the dtype/device of the tensors it is combined with;
    # a bare torch.tensor(180/pi) would be float32 and round the TAN
    # projection scale by ~1e-8 relative.
    _R0_DEG = 180.0 / torch.pi

    def __init__(
        self,
        source_hdus: List[PrimaryHDU],
        target_wcs: WCS,
        shape_out: Tuple[int, int],
        device: Optional[str] = None,
        num_threads: Optional[int] = None,
        requires_grad: bool = False,
        conserve_flux: bool = True,
        compute_jacobian: bool = True,
    ):
        """
        Initialize a dfreproject operation between source and target image frames.

        This constructor sets up the necessary components for reprojecting
        astronomical images from one World Coordinate System (WCS) to another.
        It stores the source images, the source and target WCS parameters, and
        creates the coordinate grid of the target image used by the
        reprojection.

        Parameters
        ----------
        source_hdus : List[PrimaryHDU]
            List of HDUs containing the data and the header information for
            the source images.
        target_wcs : WCS
            WCS for the target in an astropy.wcs compatible format.
        shape_out : Tuple[int, int]
            Shape ``(height, width)`` of the output image.
        device : str, optional
            Device to use for computations. When None, the device returned by
            ``get_device()`` is used.
        num_threads : int, optional
            Number of threads to use on CPU. Ignored when None or 0.
        requires_grad : bool
            If True, ``TensorHDU`` inputs contribute their ``tensor``
            attribute directly instead of a copy of ``data``.
        conserve_flux : bool
            If True, enables flux conservation through footprint calculation.
        compute_jacobian : bool, optional
            If True, enables non-linear flux conservation through Jacobian
            calculation. Note that this increases RAM usage. Only applied
            when ``conserve_flux`` is also True.

        Notes
        -----
        The target coordinate grid is stored as a tuple ``(y_grid, x_grid)``
        of tensors, each of shape ``(batch, height, width)``, where the second
        dimension corresponds to y (rows) and the third to x (columns).

        Examples
        --------
        >>> # Initialize the dfreproject object
        >>> reproject = Reproject(source_hdus, target_wcs, shape_out=(512, 512))
        """
        # Set device
        if device is None:
            self.device = get_device()
        else:
            self.device = torch.device(device)
        if num_threads:
            torch.set_num_threads(num_threads)

        self.requires_grad = requires_grad
        self.batch_source_images = self._prepare_source_images(source_hdus)

        # Initialize the WCS parameter dictionaries
        self.batch_source_wcs_params = self._prepare_batch_wcs_params(source_hdus)
        self.target_wcs_params = self._extract_wcs_params(target_wcs)
        self.target_wcs = target_wcs

        # Define target grid
        self.target_grid = self._create_batch_target_grid(shape_out)

        # Define flux conservation booleans
        self.conserve_flux = conserve_flux
        self.compute_jacobian = compute_jacobian

    def _prepare_source_images(self, source_hdus: List[PrimaryHDU]) -> torch.Tensor:
        """
        Prepare batch of source images as a single tensor.

        Parameters
        ----------
        source_hdus : List[PrimaryHDU]
            List of HDUs containing the data and the header information for
            the source images.

        Returns
        -------
        source_image : torch.Tensor
            Stack of source image tensors (one entry per HDU).
        """
        source_images = []
        for hdu in source_hdus:
            if self.requires_grad and isinstance(hdu, TensorHDU):
                # Keep the caller's tensor so gradients can flow through it.
                img = hdu.tensor.to(self.device)
            else:
                try:
                    img = torch.tensor(
                        hdu.data, dtype=torch.float64, device=self.device
                    )
                except ValueError:
                    # In case there is a byte order error: go through a
                    # native float64 copy. Handled per HDU so that one
                    # offending array does not discard TensorHDU tensors
                    # of the other batch members.
                    img = torch.tensor(
                        np.asarray(hdu.data, dtype=np.float64).copy(),
                        dtype=torch.float64,
                        device=self.device,
                    )
            source_images.append(img)
        return torch.stack(source_images)

    def _extract_wcs_params(self, wcs: WCS) -> dict:
        """
        Extract key WCS parameters into a dictionary for efficient tensor operations.

        Parameters
        ----------
        wcs : WCS
            WCS information.

        Returns
        -------
        wcs_params : dict
            Dictionary with float64 tensors on ``self.device`` under the keys
            ``"crpix"``, ``"crval"``, ``"pc_matrix"`` and ``"cdelt"``, plus
            ``"sip_coeffs"`` as returned by ``get_sip_coeffs(wcs)``.
        """

        def as_tensor(values) -> torch.Tensor:
            return torch.tensor(values, dtype=torch.float64, device=self.device)

        return {
            "crpix": as_tensor(wcs.wcs.crpix),
            "crval": as_tensor(wcs.wcs.crval),
            "pc_matrix": as_tensor(wcs.wcs.get_pc()),
            "cdelt": as_tensor(wcs.wcs.cdelt),
            "sip_coeffs": get_sip_coeffs(wcs),
        }

    @staticmethod
    def _cd_matrix(wcs_params: dict) -> torch.Tensor:
        """
        Build the CD matrix from the PC matrix and CDELT of a parameter dict.

        FITS defines ``CD_ij = CDELT_i * PC_ij``, i.e. CDELT scales the *rows*
        of PC. A plain ``pc_matrix * cdelt`` would broadcast over the last
        axis and scale the columns instead, which is only equivalent when PC
        is diagonal or both CDELT values are equal.
        """
        return wcs_params["cdelt"].unsqueeze(-1) * wcs_params["pc_matrix"]

    def _prepare_batch_wcs_params(
        self, source_hdus: Union[List[PrimaryHDU], List[TensorHDU]]
    ) -> List[dict]:
        """
        Prepare batch of WCS parameters.

        Parameters
        ----------
        source_hdus : List[PrimaryHDU]
            List of HDUs containing the data and the header information for
            the source images.

        Returns
        -------
        List[dict]
            List of dictionaries containing the WCS parameters extracted from
            each HDU header.
        """
        return [self._extract_wcs_params(WCS(hdu.header)) for hdu in source_hdus]

    def _create_batch_target_grid(
        self, shape_out: Tuple[int, int]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create a batched target grid matching the number of source images.

        Parameters
        ----------
        shape_out : Tuple[int, int]
            Shape ``(height, width)`` of the output image.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            ``(y_grid, x_grid)``, each of shape ``(batch, height, width)`` and
            holding 0-based pixel indices.
        """
        B = len(self.batch_source_images)
        H, W = shape_out

        # Create base grids once
        y_base = torch.arange(H, dtype=torch.float64, device=self.device)
        x_base = torch.arange(W, dtype=torch.float64, device=self.device)

        # expand() creates broadcast views without allocating new memory
        y_grid = y_base.view(1, -1, 1).expand(B, H, W)
        x_grid = x_base.view(1, 1, -1).expand(B, H, W)
        return y_grid, x_grid

    def calculate_skyCoords(
        self, x_grid=None, y_grid=None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate sky coordinates.

        There are four primary steps:
        1. Apply shift
        2. Apply SIP distortion
        3. Apply CD matrix
        4. Apply transformation to celestial coordinates using the Gnomonic
           Projection

        These steps use the target wcs parameters.

        Parameters
        ----------
        x_grid : torch.Tensor, optional
            Batch of x-coordinates. If None, uses target grid x-coordinates.
        y_grid : torch.Tensor, optional
            Batch of y-coordinates. If None, uses target grid y-coordinates.
            If either grid is None, both are taken from the target grid.

        Returns
        -------
        tuple
            Batched RA and Dec coordinates in degrees, RA wrapped to
            [0, 360).
        """
        if x_grid is None or y_grid is None:
            y_grid, x_grid = self.target_grid
        else:
            if not isinstance(x_grid, torch.Tensor):
                x_grid = torch.tensor(x_grid, dtype=torch.float64, device=self.device)
            if not isinstance(y_grid, torch.Tensor):
                y_grid = torch.tensor(y_grid, dtype=torch.float64, device=self.device)
        # Promote single (H, W) grids to a batch of one.
        if x_grid.dim() == 2:
            x_grid = x_grid.unsqueeze(0)
        if y_grid.dim() == 2:
            y_grid = y_grid.unsqueeze(0)
        B, H, W = y_grid.shape

        # Unpack target WCS parameters
        params = self.target_wcs_params
        crpix = params["crpix"]
        crval = params["crval"]
        sip_coeffs = params["sip_coeffs"]

        # Pixel offsets from the reference pixel (CRPIX is 1-based, grids
        # are 0-based).
        u = x_grid - (crpix[0] - 1)
        v = y_grid - (crpix[1] - 1)

        # Apply SIP distortion if present
        if sip_coeffs is not None:
            u, v = apply_sip_distortion(u, v, sip_coeffs, self.device)

        # Apply the CD matrix (CDELT_i * PC_ij) to every pixel offset.
        cd_matrix = self._cd_matrix(params)
        pixel_offsets = torch.stack([u.reshape(B, -1), v.reshape(B, -1)], dim=-1)
        del u, v  # Free immediately
        transformed = torch.bmm(
            pixel_offsets, cd_matrix.T.unsqueeze(0).expand(B, -1, -1)
        )
        del pixel_offsets, cd_matrix
        x_scaled = transformed[:, :, 0].reshape(B, H, W)
        y_scaled = transformed[:, :, 1].reshape(B, H, W)
        del transformed

        # Radial distance from the projection centre, in degrees.
        r = torch.sqrt(x_scaled.pow(2).add_(y_scaled.pow(2)))

        # Native longitude; left at zero where r == 0 so atan2(0, 0) is never
        # evaluated.
        phi_rad = torch.zeros_like(r)
        non_zero_r = r > 0
        phi_rad[non_zero_r] = torch.atan2(
            -x_scaled[non_zero_r], y_scaled[non_zero_r]
        )
        del x_scaled, y_scaled, non_zero_r

        # Native latitude of the TAN projection: theta = atan2(r0, r).
        r0 = torch.tensor(self._R0_DEG, dtype=r.dtype, device=r.device)
        theta_rad = torch.atan2(r0, r)
        del r, r0

        # Pre-compute trig values
        ra0_rad = crval[0] * (torch.pi / 180.0)
        dec0_rad = crval[1] * (torch.pi / 180.0)
        sin_theta = torch.sin(theta_rad)
        cos_theta = torch.cos(theta_rad)
        sin_phi = torch.sin(phi_rad)
        cos_phi = torch.cos(phi_rad)
        del theta_rad, phi_rad
        sin_dec0 = torch.sin(dec0_rad)
        cos_dec0 = torch.cos(dec0_rad)

        # Compute dec. Clamp so rounding slightly beyond +/-1 cannot turn
        # into NaN inside arcsin.
        sin_dec = sin_theta * sin_dec0 + cos_theta * cos_dec0 * cos_phi
        dec_rad = torch.arcsin(torch.clamp(sin_dec, -1.0, 1.0))
        del sin_dec

        # Compute ra
        ra_rad = ra0_rad + torch.atan2(
            -cos_theta * sin_phi, sin_theta * cos_dec0 - cos_theta * sin_dec0 * cos_phi
        )
        del sin_theta, cos_theta, sin_phi, cos_phi, sin_dec0, cos_dec0

        # Convert to degrees
        ra = torch.rad2deg(ra_rad) % 360.0
        dec = torch.rad2deg(dec_rad)
        return ra, dec

    def calculate_sourceCoords(self):
        """
        Calculate source image pixel coordinates corresponding to each target image pixel.

        This function repeats the same steps as :meth:`calculate_skyCoords`
        in the opposite order and with each source image's WCS parameters.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            ``(batch_x_pixel, batch_y_pixel)``: 0-based source pixel
            coordinates, each of shape ``(batch, height, width)``.

        Raises
        ------
        ValueError
            If any sky position maps to a native latitude whose sine is
            smaller than 1e-10 in magnitude (TAN projection singularity).
        """
        B = len(self.batch_source_images)
        y_grid, x_grid = self.target_grid
        _, H, W = y_grid.shape
        eps = 1e-10  # singularity threshold on |sin(theta)|

        # Pre-allocate output tensors
        batch_x_pixel = torch.zeros((B, H, W), dtype=torch.float64, device=self.device)
        batch_y_pixel = torch.zeros((B, H, W), dtype=torch.float64, device=self.device)

        # Process one source image at a time to bound peak memory.
        for b, source_wcs_params in enumerate(self.batch_source_wcs_params):
            # Sky coordinates of the target pixels for this batch element
            # (slices keep the batch dimension expected by the callee).
            ra, dec = self.calculate_skyCoords(x_grid[b : b + 1], y_grid[b : b + 1])
            ra_rad = torch.deg2rad(ra.squeeze(0))
            dec_rad = torch.deg2rad(dec.squeeze(0))
            del ra, dec

            crpix = source_wcs_params["crpix"]
            crval = source_wcs_params["crval"]
            sip_coeffs = source_wcs_params["sip_coeffs"]

            ra0_rad = crval[0] * torch.pi / 180.0
            dec0_rad = crval[1] * torch.pi / 180.0

            # Shared trigonometric terms, evaluated once.
            delta_ra = ra_rad - ra0_rad
            sin_dec, cos_dec = torch.sin(dec_rad), torch.cos(dec_rad)
            sin_dec0, cos_dec0 = torch.sin(dec0_rad), torch.cos(dec0_rad)
            sin_dra, cos_dra = torch.sin(delta_ra), torch.cos(delta_ra)
            del ra_rad, dec_rad, delta_ra

            # Convert from world to native spherical coordinates
            phi_rad = torch.atan2(
                -cos_dec * sin_dra,
                sin_dec * cos_dec0 - cos_dec * sin_dec0 * cos_dra,
            )
            # Clamp guards arcsin against arguments that exceed +/-1 by
            # rounding (e.g. exactly at the reference position).
            theta_rad = torch.arcsin(
                torch.clamp(
                    sin_dec * sin_dec0 + cos_dec * cos_dec0 * cos_dra, -1.0, 1.0
                )
            )
            del sin_dec, cos_dec, sin_dra, cos_dra

            # Apply TAN projection
            sin_phi, cos_phi = torch.sin(phi_rad), torch.cos(phi_rad)
            sin_theta, cos_theta = torch.sin(theta_rad), torch.cos(theta_rad)
            del phi_rad, theta_rad

            # Check for singularity
            if torch.any(torch.abs(sin_theta) < eps):
                raise ValueError("Singularity in tans2x: theta close to 0 degrees")

            r = self._R0_DEG * cos_theta / sin_theta
            del cos_theta, sin_theta
            x_scaled = -r * sin_phi
            y_scaled = r * cos_phi
            del sin_phi, cos_phi, r

            # Apply inverse CD matrix
            cd_inv = torch.linalg.inv(self._cd_matrix(source_wcs_params))
            standard_coords = torch.stack(
                [x_scaled.reshape(-1), y_scaled.reshape(-1)], dim=1
            )
            del x_scaled, y_scaled
            pixel_offsets = torch.matmul(standard_coords, cd_inv.T)
            u = pixel_offsets[:, 0].reshape(H, W)
            v = pixel_offsets[:, 1].reshape(H, W)
            del cd_inv, pixel_offsets, standard_coords

            if sip_coeffs is not None:
                u, v = apply_inverse_sip_distortion(u, v, sip_coeffs, self.device)

            # Add reference pixel (back to 0-based pixel indices)
            batch_x_pixel[b] = u + (crpix[0] - 1)
            batch_y_pixel[b] = v + (crpix[1] - 1)
            del u, v

        return batch_x_pixel, batch_y_pixel

    def compute_pixel_map(self):
        """
        Compute and return source-space pixel coordinates for reprojection.

        This method exposes the pixel mapping step used by
        :meth:`interpolate_source_image`, allowing callers to inspect, cache,
        or reuse the coordinate transform without immediately interpolating
        image values.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            ``(x_source, y_source)`` pixel coordinates in source-image index
            space (unnormalized; not in ``grid_sample`` [-1, 1] coordinates).
        """
        return self.calculate_sourceCoords()

    def apply_pixel_map(
        self,
        pixel_map: Tuple[torch.Tensor, torch.Tensor],
        interpolation_mode: str = "bilinear",
    ) -> torch.Tensor:
        """
        Interpolate the source image using a precomputed pixel map.

        Use this method with a map returned by :meth:`compute_pixel_map` when
        you want to reuse the same coordinate transform across repeated
        interpolation calls.

        Parameters
        ----------
        pixel_map : Tuple[torch.Tensor, torch.Tensor]
            ``(x_source, y_source)`` coordinates from :meth:`compute_pixel_map`.
        interpolation_mode : str, default 'bilinear'
            Interpolation algorithm. Options: ``'nearest'``, ``'bilinear'``,
            ``'bicubic'``, ``'lanczos'``.

        Returns
        -------
        torch.Tensor
            Reprojected image on the target grid, with singleton dimensions
            squeezed. Pixels whose footprint is not above ``EPSILON`` are NaN.
            If no pixel has a valid footprint at all, a warning is logged and
            the interpolated image divided by the footprint is returned.
        """
        x_source, y_source = pixel_map
        _, H, W = self.batch_source_images.shape

        # Map source pixel indices [0, size - 1] onto grid_sample's [-1, 1].
        x_normalized = 2.0 * (x_source / (W - 1)) - 1.0
        y_normalized = 2.0 * (y_source / (H - 1)) - 1.0

        # Interpolate the image together with an all-ones channel; the second
        # channel yields the footprint (fraction of in-bounds contribution).
        source_images = self.batch_source_images.unsqueeze(1)
        combined_result = interpolate_image(
            torch.cat([source_images, torch.ones_like(source_images)], dim=1),
            torch.stack([x_normalized, y_normalized], dim=-1),
            interpolation_mode,
        )
        del source_images, x_normalized, y_normalized

        interpolated = combined_result[:, 0].squeeze()
        footprint = combined_result[:, 1].squeeze()
        del combined_result
        valid_pixels = footprint > EPSILON

        if not torch.any(valid_pixels):
            logger.warning(
                "No valid pixels found in footprint! Using raw interpolated values"
            )
            return interpolated / footprint

        result = torch.full_like(interpolated, torch.nan)
        if not self.conserve_flux:
            result[valid_pixels] = interpolated[valid_pixels]
            return result

        # Footprint correction for partially covered pixels.
        result[valid_pixels] = interpolated[valid_pixels] / footprint[valid_pixels]
        del interpolated, footprint

        if self.compute_jacobian:
            # gradient2d is unpacked as (d/dy, d/dx) of each coordinate map.
            dy_x, dx_x = gradient2d(x_source)
            dy_y, dx_y = gradient2d(y_source)
            jacobian_det = dx_x * dy_y
            del dy_y
            jacobian_det -= dy_x * dx_y
            del dx_x, dy_x, dx_y
            # The area ratio is |det J|; the sign only encodes a parity flip
            # between source and target and must not negate the flux.
            result[valid_pixels] *= jacobian_det.abs().squeeze(0)[valid_pixels]
            del jacobian_det
        return result

    def interpolate_source_image(self, interpolation_mode="bilinear") -> torch.Tensor:
        """
        Interpolate the source image at the calculated source coordinates with flux conservation.

        This is a convenience wrapper around:

        1. :meth:`compute_pixel_map`
        2. :meth:`apply_pixel_map`

        If you need access to the pixel map itself, call
        :meth:`compute_pixel_map` directly and pass it to
        :meth:`apply_pixel_map`.

        Parameters
        ----------
        interpolation_mode : str, default 'bilinear'
            The interpolation mode to use when sampling the source image.
            Options include:

            - 'nearest' : Nearest neighbor interpolation (no interpolation)
            - 'bilinear' : Bilinear interpolation (default)
            - 'bicubic' : Bicubic interpolation
            - 'lanczos' : Lanczos interpolation

        Returns
        -------
        torch.Tensor
            The reprojected image on the target grid. Pixel values are
            interpolated from the source image according to the WCS
            transformation.

        Notes
        -----
        Flux conservation is a two-step approach:

        1. Local flux density conservation: the image and a "ones" tensor are
           interpolated together, and the interpolated image is divided by
           the interpolated ones tensor (footprint). This matters for pixels
           at the edge of the input image that only partially fill an output
           pixel. Disabled with ``conserve_flux=False``.
        2. Jacobian correction: the footprint-corrected flux is multiplied by
           the absolute determinant of the Jacobian of the pixel map to
           handle changes in pixel area. Disabled with
           ``compute_jacobian=False``.

        Areas in the target image that map outside the source image
        boundaries will be filled with NaNs.
        """
        # Convenience path: compute + apply in one call.
        return self.apply_pixel_map(self.compute_pixel_map(), interpolation_mode)
def calculate_reprojection(
    source_hdus: Union[
        PrimaryHDU,
        TensorHDU,
        Tuple[np.ndarray, Union[WCS, Header]],
        Tuple[torch.Tensor, Union[WCS, Header]],
        List[Union[PrimaryHDU, Tuple[np.ndarray, Union[WCS, Header]]]],
    ],
    target_wcs: Union[WCS, Header],
    shape_out: Optional[Tuple[int, int]] = None,
    order: str = "nearest",
    device: Optional[str] = None,
    num_threads: Optional[int] = None,
    requires_grad: bool = False,
    conserve_flux: bool = True,
    compute_jacobian: bool = True,
    max_memory_mb: Optional[float] = None,
    chunk_safety_factor: float = 0.8,
    show_chunk_progress: bool = True,
    show_log: bool = False,
) -> Union[np.ndarray, torch.Tensor]:
    """
    Reproject an astronomical image from a source WCS to a target WCS.

    This high-level function provides a convenient interface for image
    reprojection, handling all the necessary steps: input normalization, WCS
    extraction, tensor creation, and interpolation.

    Parameters
    ----------
    source_hdus : PrimaryHDU, TensorHDU, tuple, or list
        The source image(s) to be reprojected. Can be:

        - A PrimaryHDU
        - A TensorHDU
        - A tuple of (np.ndarray or torch.Tensor, WCS or Header)
        - A non-empty list of any of the above
    target_wcs : Union[WCS, Header]
        WCS information for the target. If a Header is passed it will be
        converted to WCS.
    shape_out : Optional[Tuple[int, int]]
        Shape of the resampled array. If not provided, the shape of the first
        source image is used.
    order : str, default 'nearest'
        The interpolation method to use when resampling the source image.
        Options:

        - 'nearest' : Nearest neighbor interpolation (fastest, default);
          'nearest-neighbors' is accepted as an alias
        - 'bilinear' : Bilinear interpolation (good balance of speed/quality)
        - 'bicubic' : Bicubic interpolation (high quality, slow)
        - 'lanczos' : Lanczos 3-lobe interpolation (highest quality, slowest)
    device : str, optional
        Device to use for computations. Defaults to GPU if available,
        otherwise uses CPU.
    num_threads : int, optional
        Number of threads to use on CPU.
    requires_grad : bool, optional
        If True, inputs are wrapped as TensorHDU and the result is returned
        as a PyTorch tensor instead of a NumPy array.
    conserve_flux : bool, optional
        If True (default), enables flux conservation through footprint
        calculations.
    compute_jacobian : bool, optional
        If True (default), enables non-linear flux conservation through
        Jacobian calculation. Note that this slightly increases RAM usage.
    max_memory_mb : Optional[float], optional
        Maximum memory to use in megabytes for chunked processing. If None
        (default), processes the entire image at once without chunking.
    chunk_safety_factor : float, optional
        Safety factor (0-1) for chunked processing. Default 0.8. Only used if
        max_memory_mb is set.
    show_chunk_progress : bool, optional
        Whether to log progress when using chunked processing. Default True.
        Only used if max_memory_mb is set.
    show_log : bool, optional
        Whether to log progress. Default False. Accepted for interface
        compatibility; it is not read by this function.

    Returns
    -------
    numpy.ndarray or torch.Tensor
        The reprojected image as a float32 numpy ndarray (default) or a CPU
        PyTorch tensor if requires_grad=True.

    Raises
    ------
    ValueError
        If ``source_hdus`` is an empty list or ``order`` is not a valid
        interpolation order.
    TypeError
        If an input item is not a PrimaryHDU, TensorHDU, or a
        (data, wcs/header) tuple.

    Examples
    --------
    >>> from astropy.io import fits
    >>> from astropy.wcs import WCS
    >>> from dfreproject.reproject import calculate_reprojection
    >>>
    >>> # Open source and target images
    >>> source_hdu = fits.open('source_image.fits')[0]
    >>> target_hdu = fits.open('target_grid.fits')[0]
    >>> target_wcs = WCS(target_hdu.header)
    >>>
    >>> # Perform reprojection with bilinear interpolation
    >>> reprojected = calculate_reprojection(
    ...     source_hdus=source_hdu,
    ...     target_wcs=target_wcs,
    ...     shape_out=target_hdu.data.shape,
    ...     order='bilinear'
    ... )
    >>>
    >>> # Perform chunked reprojection with 2GB memory limit
    >>> reprojected = calculate_reprojection(
    ...     source_hdus=source_hdu,
    ...     target_wcs=target_wcs,
    ...     shape_out=(8000, 8000),
    ...     order='bilinear',
    ...     max_memory_mb=2000
    ... )
    >>> # Save as FITS
    >>> output_hdu = fits.PrimaryHDU(data=reprojected, header=target_hdu.header)
    >>> output_hdu.writeto('reprojected_image.fits', overwrite=True)
    """

    def normalize_to_hdu(item):
        """Convert one supported input item into a PrimaryHDU/TensorHDU."""
        if isinstance(item, TensorHDU):
            # Already in the richest supported form; usable in both modes.
            return item
        if isinstance(item, PrimaryHDU):
            if requires_grad:
                return TensorHDU(data=item.data, header=item.header)
            return item
        if isinstance(item, tuple) and len(item) == 2:
            data, wcs_or_header = item
            if isinstance(wcs_or_header, Header):
                header = wcs_or_header
            elif isinstance(wcs_or_header, WCS):
                header = wcs_or_header.to_header(relax=True)
            else:
                raise TypeError("Expected WCS or Header in tuple.")
            hdu_type = TensorHDU if requires_grad else PrimaryHDU
            return hdu_type(data=data, header=header)
        raise TypeError(
            "Each item must be a PrimaryHDU, TensorHDU, or a (data, wcs/header) tuple."
        )

    # Normalize source input to a list of HDUs
    if isinstance(source_hdus, list):
        if not source_hdus:
            raise ValueError("source_hdus must contain at least one image.")
        source_hdus = [normalize_to_hdu(item) for item in source_hdus]
    else:
        source_hdus = [normalize_to_hdu(source_hdus)]

    # Validate before any tensors are allocated so a bad order fails fast.
    order = validate_interpolation_order(order)

    # Convert Header to WCS if needed
    if isinstance(target_wcs, Header):
        target_wcs = WCS(target_wcs)

    if not shape_out:
        shape_out = source_hdus[0].data.shape

    reprojection = Reproject(
        source_hdus=source_hdus,
        target_wcs=target_wcs,
        shape_out=shape_out,
        device=device,
        num_threads=num_threads,
        requires_grad=requires_grad,
        conserve_flux=conserve_flux,
        compute_jacobian=compute_jacobian,
    )

    # Choose between chunked and non-chunked processing
    if max_memory_mb is not None:
        result = reproject_chunked(
            reprojection,
            max_memory_mb=max_memory_mb,
            safety_factor=chunk_safety_factor,
            interpolation_mode=order,
            show_progress=show_chunk_progress,
        ).squeeze(0)
    else:
        result = reprojection.interpolate_source_image(interpolation_mode=order)

    # Convert output format
    if requires_grad:
        result = result.cpu()
    else:
        # detach() is a no-op for tensors without autograd history and keeps
        # .numpy() safe if one ever carries it.
        result = result.detach().cpu().numpy().astype(np.float32)
    torch.cuda.empty_cache()
    return result