from typing import Any, Dict, Optional, Tuple

import astropy
import torch
[docs]
def get_sip_coeffs(
    wcs: "astropy.wcs.WCS",
) -> Optional[Dict[str, Any]]:
    """
    Extract SIP polynomial coefficients from a WCS object.

    Parameters:
    -----------
    wcs : astropy.wcs.WCS
        WCS object potentially containing SIP distortion. Only its ``sip``
        attribute is read.

    Returns:
    --------
    dict or None:
        ``None`` when ``wcs`` has no ``sip`` attribute or it is ``None``.
        Otherwise a dictionary with the keys

        - ``"a_order"``, ``"b_order"``: forward polynomial orders,
        - ``"a"``, ``"b"``: forward coefficient matrices,
        - ``"ap_order"``, ``"bp_order"``: inverse polynomial orders
          (``0`` when the inverse polynomial is not available),
        - ``"ap"``, ``"bp"``: inverse coefficient matrices; each key is only
          present when the matching order is greater than zero.
    """
    # The annotation above is a string on purpose: ``import astropy`` alone
    # does not guarantee that the ``astropy.wcs`` submodule is loaded, so
    # evaluating ``astropy.wcs.WCS`` at definition time could fail.
    sip = getattr(wcs, "sip", None)
    if sip is None:
        return None

    # Forward distortion matrices.
    sip_coeffs: Dict[str, Any] = {
        "a_order": sip.a_order,
        "b_order": sip.b_order,
        "a": sip.a,
        "b": sip.b,
    }

    # Inverse coefficients are optional; an order of 0 marks them as absent
    # so callers can fall back to an iterative inversion.
    for name in ("ap", "bp"):
        order_key = f"{name}_order"
        order = getattr(sip, order_key, 0)
        if order > 0:
            sip_coeffs[order_key] = order
            sip_coeffs[name] = getattr(sip, name)
        else:
            sip_coeffs[order_key] = 0

    return sip_coeffs
[docs]
def apply_sip_distortion(
    u: torch.Tensor,
    v: torch.Tensor,
    sip_coeffs: Optional[Dict[str, Any]],
    device: str = "cpu",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply SIP distortion to intermediate pixel coordinates.

    Evaluates ``u + sum(A[i, j] * u**i * v**j)`` and
    ``v + sum(B[i, j] * u**i * v**j)`` over ``i + j <= order``, where each
    polynomial uses its own order. The (0, 0) coefficient is skipped.

    Parameters:
    -----------
    u, v : torch.Tensor
        Intermediate pixel coordinates (before distortion). Non-tensor
        inputs are converted with ``torch.tensor`` on ``device``.
    sip_coeffs : dict or None
        Dictionary with keys ``"a_order"``, ``"a"`` and ``"b"``;
        ``"b_order"`` is used when present and defaults to ``"a_order"``.
        ``None`` means "no distortion" and returns the inputs unchanged.
    device : str, optional
        Device on which the coefficient matrices (and any converted inputs)
        are created.

    Returns:
    --------
    tuple:
        (u', v') distorted coordinates.
    """
    if sip_coeffs is None:
        return u, v

    # Convert each coordinate independently so a tensor ``u`` paired with a
    # plain sequence ``v`` is handled too.
    if not isinstance(u, torch.Tensor):
        u = torch.tensor(u, device=device)
    if not isinstance(v, torch.Tensor):
        v = torch.tensor(v, device=device)

    a_order = sip_coeffs["a_order"]
    # A and B may legitimately have different orders; indexing B with A's
    # order either drops B terms or runs off the end of the B matrix.
    b_order = sip_coeffs.get("b_order", a_order)
    a_matrix = torch.tensor(sip_coeffs["a"], dtype=torch.float64, device=device)
    b_matrix = torch.tensor(sip_coeffs["b"], dtype=torch.float64, device=device)

    # Precompute the powers of u and v once, up to the larger of the two
    # orders, so every monomial is a single product.
    max_order = max(a_order, b_order)
    u_powers = [torch.ones_like(u)]
    v_powers = [torch.ones_like(v)]
    for _ in range(max_order):
        u_powers.append(u_powers[-1] * u)
        v_powers.append(v_powers[-1] * v)

    # Accumulate the polynomial corrections.
    f_u = torch.zeros_like(u)
    f_v = torch.zeros_like(v)
    for i in range(max_order + 1):
        for j in range(max_order + 1 - i):
            if i == 0 and j == 0:
                continue  # Skip the (0,0) term
            degree = i + j
            term = u_powers[i] * v_powers[j]
            if degree <= a_order:
                f_u = f_u + a_matrix[i, j] * term
            if degree <= b_order:
                f_v = f_v + b_matrix[i, j] * term

    # Return corrected coordinates
    return u + f_u, v + f_v
[docs]
def apply_inverse_sip_distortion(
    u: torch.Tensor,
    v: torch.Tensor,
    sip_coeffs: Optional[Dict[str, Any]],
    device: str = "cpu",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply inverse SIP distortion to go from distorted to intermediate coordinates.

    When both inverse polynomials are available, evaluates
    ``u + sum(AP[i, j] * u**i * v**j)`` and
    ``v + sum(BP[i, j] * u**i * v**j)`` over ``i + j <= order``, each with
    its own order. Otherwise delegates to ``iterative_inverse_sip``.

    Parameters:
    -----------
    u, v : torch.Tensor
        Distorted coordinates. Non-tensor inputs are converted with
        ``torch.tensor`` on ``device``.
    sip_coeffs : dict or None
        Dictionary with keys ``"ap_order"`` and ``"bp_order"`` and, when both
        are non-zero, ``"ap"`` and ``"bp"``. ``None`` means "no distortion"
        and returns the inputs unchanged.
    device : str, optional
        Device to place tensors on.

    Returns:
    --------
    tuple:
        (u', v') undistorted coordinates. On the polynomial path the result
        is float32 on ``device``; on the iterative path it is whatever
        ``iterative_inverse_sip`` returns.
    """
    if sip_coeffs is None:
        return u, v

    # Convert each coordinate independently so a tensor ``u`` paired with a
    # plain sequence ``v`` is handled too.
    if not isinstance(u, torch.Tensor):
        u = torch.tensor(u, device=device)
    if not isinstance(v, torch.Tensor):
        v = torch.tensor(v, device=device)

    ap_order = sip_coeffs["ap_order"]
    bp_order = sip_coeffs["bp_order"]

    # Use iterative method if inverse SIP coefficients are not defined
    if ap_order == 0 or bp_order == 0:
        return iterative_inverse_sip(u, v, sip_coeffs, device)

    # The polynomial path works in float32 on the requested device.
    u = u.to(dtype=torch.float32, device=device)
    v = v.to(dtype=torch.float32, device=device)
    ap_matrix = torch.tensor(sip_coeffs["ap"], dtype=torch.float32, device=device)
    bp_matrix = torch.tensor(sip_coeffs["bp"], dtype=torch.float32, device=device)

    # AP and BP may have different orders; indexing BP with AP's order either
    # drops BP terms or runs off the end of the BP matrix. Precompute powers
    # up to the larger order and gate each polynomial on its own order.
    max_order = max(ap_order, bp_order)
    u_powers = [torch.ones_like(u)]
    v_powers = [torch.ones_like(v)]
    for _ in range(max_order):
        u_powers.append(u_powers[-1] * u)
        v_powers.append(v_powers[-1] * v)

    # Accumulate the correction terms.
    f_u = torch.zeros_like(u)
    f_v = torch.zeros_like(v)
    for i in range(max_order + 1):
        for j in range(max_order + 1 - i):
            if i == 0 and j == 0:
                # NOTE(review): the constant AP[0, 0] / BP[0, 0] terms are
                # ignored here -- confirm this is intended for inverse fits
                # that carry a non-zero constant offset.
                continue
            degree = i + j
            term = u_powers[i] * v_powers[j]
            if degree <= ap_order:
                f_u = f_u + ap_matrix[i, j] * term
            if degree <= bp_order:
                f_v = f_v + bp_matrix[i, j] * term

    # Drop the power tables before allocating the outputs to keep the peak
    # memory footprint down for large coordinate grids.
    del u_powers, v_powers, ap_matrix, bp_matrix
    return u + f_u, v + f_v
[docs]
def iterative_inverse_sip(
    u: torch.Tensor,
    v: torch.Tensor,
    sip_coeffs: Optional[Dict[str, Any]],
    device: str = "cpu",
    max_iter: int = 20,
    tol: float = 1e-8,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Iteratively solve for undistorted coordinates when inverse SIP coefficients
    are not available.

    Starting from the distorted coordinates as the initial guess, repeatedly
    applies the forward distortion (``apply_sip_distortion``) and adds the
    residual back onto the guess until the largest absolute residual drops
    below ``tol`` or ``max_iter`` iterations have run.

    Parameters:
    -----------
    u, v : torch.Tensor
        Distorted coordinates. Non-tensor inputs are converted with
        ``torch.tensor`` on ``device``.
    sip_coeffs : dict or None
        SIP coefficients, passed through to ``apply_sip_distortion``.
    device : str, optional
        Device to place tensors on.
    max_iter : int, optional
        Maximum number of iterations.
    tol : float, optional
        Convergence tolerance on the largest absolute residual. With
        low-precision inputs the default may not be reachable, in which case
        all ``max_iter`` iterations run.

    Returns:
    --------
    tuple:
        (u', v') undistorted coordinates.
    """
    # Convert each coordinate independently so a tensor ``u`` paired with a
    # plain sequence ``v`` is handled too.
    if not isinstance(u, torch.Tensor):
        u = torch.tensor(u, device=device)
    if not isinstance(v, torch.Tensor):
        v = torch.tensor(v, device=device)

    # Initial guess: undistorted = distorted
    u_undist = u.clone()
    v_undist = v.clone()

    for _ in range(max_iter):
        # Apply forward SIP to get predicted distorted coordinates
        u_pred, v_pred = apply_sip_distortion(u_undist, v_undist, sip_coeffs, device)

        # Residual between the target and the prediction.
        u_error = u - u_pred
        v_error = v - v_pred

        # Reducing an empty tensor raises, and there is nothing to refine.
        if u_error.numel() == 0 or v_error.numel() == 0:
            break

        # Largest absolute residual across both axes, without building a
        # concatenated copy of the residuals every iteration.
        max_error = torch.maximum(u_error.abs().max(), v_error.abs().max())
        if max_error < tol:
            break

        # Update undistorted coordinates
        u_undist = u_undist + u_error
        v_undist = v_undist + v_error

    return u_undist, v_undist