# Source code for dfreproject.sip

from typing import Any, Dict, Optional, Tuple, Union

import astropy
import torch


def get_sip_coeffs(
    wcs: "astropy.wcs.WCS",
) -> Optional[Dict[str, Any]]:
    """
    Extract SIP polynomial coefficients from a WCS object.

    Parameters:
    -----------
    wcs : astropy.wcs.WCS
        WCS object potentially containing SIP distortion

    Returns:
    --------
    dict or None:
        None when ``wcs`` has no ``sip`` attribute or it is None.
        Otherwise a dictionary with the keys ``a_order``, ``b_order``,
        ``a`` and ``b`` copied from ``wcs.sip``, plus ``ap_order`` and
        ``bp_order``. The ``ap`` / ``bp`` matrices are only included when
        the corresponding order is greater than zero; otherwise that
        order is reported as 0 and the matrix key is absent.
    """
    # The annotation on ``wcs`` is a string so that defining this function
    # does not require the ``astropy.wcs`` submodule to be loaded already.
    sip = getattr(wcs, "sip", None)
    if sip is None:
        return None

    sip_coeffs: Dict[str, Any] = {
        "a_order": sip.a_order,
        "b_order": sip.b_order,
        "a": sip.a,
        "b": sip.b,
    }

    # Inverse coefficients are optional: a missing, None or non-positive
    # order is reported as 0 and no matrix is stored for it.
    for name in ("ap", "bp"):
        order_key = f"{name}_order"
        order = getattr(sip, order_key, None)
        if order is not None and order > 0:
            sip_coeffs[order_key] = order
            sip_coeffs[name] = getattr(sip, name)
        else:
            sip_coeffs[order_key] = 0

    return sip_coeffs
def apply_sip_distortion(
    u: torch.Tensor,
    v: torch.Tensor,
    sip_coeffs: Optional[Dict[str, Any]],
    device: Union[str, torch.device] = "cpu",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply SIP distortion to intermediate pixel coordinates.

    Evaluates ``f(u, v) = sum_{i, j} A[i, j] * u**i * v**j`` over
    ``0 < i + j <= a_order`` and the analogous sum ``g(u, v)`` with ``B``
    and ``b_order``, then returns ``(u + f, v + g)``. The constant (0, 0)
    term is skipped.

    Parameters:
    -----------
    u, v : torch.Tensor
        Intermediate pixel coordinates (before distortion). Non-tensor
        inputs are converted with ``torch.tensor`` on ``device``.
    sip_coeffs : dict or None
        Mapping with the keys ``a_order``, ``a`` and ``b`` (and optionally
        ``b_order``, which defaults to ``a_order``). When None the inputs
        are returned unchanged.
    device : str or torch.device, optional
        Device to place the coefficient matrices (and converted inputs) on

    Returns:
    --------
    tuple:
        (u', v') distorted coordinates, with the same shapes as the inputs
    """
    if sip_coeffs is None:
        return u, v

    # Convert each coordinate independently so a tensor ``u`` paired with a
    # non-tensor ``v`` is handled as well.
    if not isinstance(u, torch.Tensor):
        u = torch.tensor(u, device=device)
    if not isinstance(v, torch.Tensor):
        v = torch.tensor(v, device=device)

    # Each polynomial is bounded by its own order; using a single order for
    # both would either drop terms or index past the smaller matrix.
    a_order = sip_coeffs["a_order"]
    b_order = sip_coeffs.get("b_order", a_order)
    max_order = max(a_order, b_order)

    # as_tensor avoids the copy/warning torch.tensor emits for tensor input.
    a_matrix = torch.as_tensor(sip_coeffs["a"], device=device)
    b_matrix = torch.as_tensor(sip_coeffs["b"], device=device)

    # reshape(-1) also turns a 0-dim tensor into a length-1 vector.
    u_flat = u.reshape(-1)
    v_flat = v.reshape(-1)

    def _accumulator(coords: torch.Tensor) -> torch.Tensor:
        # Integer/bool coordinates cannot receive in-place float additions,
        # so the correction is accumulated in the default float dtype.
        if coords.is_floating_point() or coords.is_complex():
            dtype = coords.dtype
        else:
            dtype = torch.get_default_dtype()
        return torch.zeros_like(coords, dtype=dtype)

    f_u_flat = _accumulator(u_flat)
    f_v_flat = _accumulator(v_flat)

    for i in range(max_order + 1):
        for j in range(max_order + 1 - i):
            if i == 0 and j == 0:
                continue  # Skip the 0,0 term
            # u^i * v^j is shared by both polynomials.
            pow_term = (u_flat**i) * (v_flat**j)
            if i + j <= a_order:
                f_u_flat += a_matrix[i, j] * pow_term
            if i + j <= b_order:
                f_v_flat += b_matrix[i, j] * pow_term

    # Add the distortion terms, restoring the original shapes.
    u_corrected = u + f_u_flat.reshape(u.shape)
    v_corrected = v + f_v_flat.reshape(v.shape)
    return u_corrected, v_corrected
def apply_inverse_sip_distortion(
    u: torch.Tensor,
    v: torch.Tensor,
    sip_coeffs: Optional[Dict[str, Any]],
    device: Union[str, torch.device] = "cpu",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply inverse SIP distortion to go from distorted to intermediate coordinates.

    When both inverse polynomials are available, evaluates
    ``sum_{i, j} AP[i, j] * u**i * v**j`` over ``0 < i + j <= ap_order``
    (and the analogous sum with ``BP`` and ``bp_order``) and adds the
    results to ``u`` and ``v``. The constant (0, 0) term is skipped. When
    either inverse order is 0 the work is delegated to
    ``iterative_inverse_sip``.

    Parameters:
    -----------
    u, v : torch.Tensor
        Distorted coordinates. Non-tensor inputs are converted with
        ``torch.tensor`` on ``device``.
    sip_coeffs : dict or None
        Mapping with the keys ``ap_order`` and ``bp_order`` and, when those
        are non-zero, ``ap`` and ``bp``. When None the inputs are returned
        unchanged.
    device : str or torch.device, optional
        Device to place the coefficient matrices (and converted inputs) on

    Returns:
    --------
    tuple:
        (u', v') undistorted coordinates, with the same shapes as the inputs
    """
    if sip_coeffs is None:
        return u, v

    ap_order = sip_coeffs["ap_order"]
    bp_order = sip_coeffs["bp_order"]

    # Use the iterative method if inverse coefficients aren't available.
    if ap_order == 0 or bp_order == 0:
        return iterative_inverse_sip(u, v, sip_coeffs, device)

    # Convert each coordinate independently so a tensor ``u`` paired with a
    # non-tensor ``v`` is handled as well.
    if not isinstance(u, torch.Tensor):
        u = torch.tensor(u, device=device)
    if not isinstance(v, torch.Tensor):
        v = torch.tensor(v, device=device)

    # as_tensor avoids the copy/warning torch.tensor emits for tensor input.
    ap_matrix = torch.as_tensor(sip_coeffs["ap"], device=device)
    bp_matrix = torch.as_tensor(sip_coeffs["bp"], device=device)

    # reshape(-1) also turns a 0-dim tensor into a length-1 vector.
    u_flat = u.reshape(-1)
    v_flat = v.reshape(-1)

    def _accumulator(coords: torch.Tensor) -> torch.Tensor:
        # Integer/bool coordinates cannot receive in-place float additions,
        # so the correction is accumulated in the default float dtype.
        if coords.is_floating_point() or coords.is_complex():
            dtype = coords.dtype
        else:
            dtype = torch.get_default_dtype()
        return torch.zeros_like(coords, dtype=dtype)

    f_u_flat = _accumulator(u_flat)
    f_v_flat = _accumulator(v_flat)

    # Each polynomial is bounded by its own order; using a single order for
    # both would either drop terms or index past the smaller matrix.
    max_order = max(ap_order, bp_order)
    for i in range(max_order + 1):
        for j in range(max_order + 1 - i):
            if i == 0 and j == 0:
                continue  # Skip the 0,0 term
            # u^i * v^j is shared by both polynomials.
            pow_term = (u_flat**i) * (v_flat**j)
            if i + j <= ap_order:
                f_u_flat += ap_matrix[i, j] * pow_term
            if i + j <= bp_order:
                f_v_flat += bp_matrix[i, j] * pow_term

    # Add the correction terms, restoring the original shapes.
    u_corrected = u + f_u_flat.reshape(u.shape)
    v_corrected = v + f_v_flat.reshape(v.shape)
    return u_corrected, v_corrected
def iterative_inverse_sip(
    u: torch.Tensor,
    v: torch.Tensor,
    sip_coeffs: Optional[Dict[str, Any]],
    device: Union[str, torch.device] = "cpu",
    max_iter: int = 20,
    tol: float = 1e-8,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Iteratively solve for undistorted coordinates when inverse SIP coefficients are not available.

    Starting from the distorted coordinates as the initial guess, each
    iteration applies ``apply_sip_distortion`` to the current estimate and
    adds the residual (target minus prediction) back onto the estimate.
    Iteration stops once the largest absolute residual falls below ``tol``
    or after ``max_iter`` iterations.

    Parameters:
    -----------
    u, v : torch.Tensor
        Distorted coordinates. Non-tensor inputs are converted with
        ``torch.tensor`` on ``device``.
    sip_coeffs : dict or None
        SIP coefficients, passed through to ``apply_sip_distortion``
    device : str or torch.device, optional
        Device to place tensors on
    max_iter : int, optional
        Maximum number of iterations
    tol : float, optional
        Convergence tolerance on the maximum absolute residual

    Returns:
    --------
    tuple:
        (u', v') undistorted coordinates
    """
    # Convert each coordinate independently so a tensor ``u`` paired with a
    # non-tensor ``v`` is handled as well.
    if not isinstance(u, torch.Tensor):
        u = torch.tensor(u, device=device)
    if not isinstance(v, torch.Tensor):
        v = torch.tensor(v, device=device)

    # Initial guess: undistorted = distorted
    u_undist = u.clone()
    v_undist = v.clone()

    # Nothing to solve for empty input; torch.max would raise on it below.
    if u.numel() == 0 and v.numel() == 0:
        return u_undist, v_undist

    for _ in range(max_iter):
        # Apply forward SIP to get predicted distorted coordinates
        u_pred, v_pred = apply_sip_distortion(u_undist, v_undist, sip_coeffs, device)

        # Residual between the target and the prediction
        u_error = u - u_pred
        v_error = v - v_pred

        # Convergence is checked before updating, so a converged estimate
        # is returned as-is.
        max_error = torch.max(
            torch.abs(torch.cat([u_error.flatten(), v_error.flatten()]))
        )
        if max_error < tol:
            break

        # Update undistorted coordinates
        u_undist = u_undist + u_error
        v_undist = v_undist + v_error

    return u_undist, v_undist