Comparison of the Astropy and Torch implementations of the WCS coordinate transformation
In this notebook, we compare Astropy’s WCS pixel-to-world transformation (`wcs_pix2world`) with our own implementation. For readability, our implementation is broken down into the individual steps performed in `reprojection.reproject.calculate_skyCoords()`.
[1]:
from astropy.io import fits
from astropy.wcs import WCS
import numpy as np
import torch
[3]:
device = 'cpu'
# Load the target image and its WCS. The context manager closes the file handle
# when done; memmap=False reads the pixel data into memory so that the data
# accessed below remains valid after the file is closed.
with fits.open('./data/Atik1442426-0035_0032_light.fits', memmap=False) as hdul:
    target_hdu = hdul[0]
    target_wcs = WCS(target_hdu.header)
    target_shape = target_hdu.data.shape
    # FITS data is big-endian; torch refuses arrays whose byte order differs
    # from the native one, so convert to native byte order first.
    target_data = target_hdu.data.astype(target_hdu.data.dtype.newbyteorder('='))
# Now create the tensor
target_image = torch.tensor(target_data, dtype=torch.float64, device=device)
Now let’s compute the reference results with Astropy.
[4]:
# Reference WCS parsed by astropy (used by all following cells)
target_wcs_astropy = WCS(target_hdu.header)
# Test pixel coordinate (0-based indices)
x_test, y_test = 200, 200
# Convert to world coordinates (RA, Dec); origin=0 selects 0-based pixel indexing.
# (No full-image pixel grid is built here: no cell uses one, and np.mgrid over
# the whole image would allocate two H x W integer arrays for nothing.)
ra_astropy, dec_astropy = target_wcs_astropy.wcs_pix2world(x_test, y_test, 0)
print(ra_astropy, dec_astropy)
# Round-trip back to pixels to gauge astropy's own numerical noise
x_roundtrip, y_roundtrip = target_wcs_astropy.wcs_world2pix(ra_astropy, dec_astropy, 0)
print(x_roundtrip, y_roundtrip)
35.24380952832857 59.09821844046949
200.00000000000136 200.00000000000182
And now using our implementation
[5]:
# Get WCS parameters
CRPIX1 = target_wcs_astropy.wcs.crpix[0]  # Reference pixel (1-based, FITS convention)
CRPIX2 = target_wcs_astropy.wcs.crpix[1]
CRVAL1 = target_wcs_astropy.wcs.crval[0]  # Reference RA
CRVAL2 = target_wcs_astropy.wcs.crval[1]  # Reference Dec
PC_matrix = target_wcs_astropy.wcs.get_pc()  # PC Matrix
CDELT = np.array(target_wcs_astropy.wcs.cdelt)  # Scaling factors

# Convert the inputs to float64 tensors. torch.as_tensor accepts Python scalars,
# numpy arrays and existing tensors alike, so x and y are always defined.
x = torch.as_tensor(x_test, device=device, dtype=torch.float64)
y = torch.as_tensor(y_test, device=device, dtype=torch.float64)

# Step 1: Compute pixel offsets from the reference pixel (as in wcsprm::p2x).
# x/y are 0-based while CRPIX is 1-based, hence the "- 1".
u = x - (CRPIX1 - 1)
v = y - (CRPIX2 - 1)
u, v = torch.broadcast_tensors(u, v)

# Step 2: Apply the linear transformation (PC rotation + CDELT scaling).
# FITS defines x_i = CDELT_i * sum_j PC_ij * p_j, i.e. CDELT scales the ROWS of
# PC (the same as astropy's WCS.pixel_scale_matrix). `PC_matrix * CDELT` would
# scale the columns instead, which is only equivalent when CDELT1 == CDELT2.
CD_matrix = CDELT[:, np.newaxis] * PC_matrix
CD_matrix = torch.as_tensor(CD_matrix, device=device, dtype=torch.float64)

# Stack the offsets on a trailing axis so scalar, 1-D and N-D inputs share one
# code path: (..., 2) @ (2, 2) -> (..., 2). No .item() calls, so autograd works.
pixel_offsets = torch.stack([u, v], dim=-1)
transformed = pixel_offsets @ CD_matrix.T
x_scaled = transformed[..., 0]
y_scaled = transformed[..., 1]

# Step 3: Invert the TAN projection (tanx2s logic from WCSLib)
# Radial distance in the projection plane
r = torch.sqrt(x_scaled ** 2 + y_scaled ** 2)
# r0 = 180/pi (R2D in WCSLib). The dtype must be float64: torch.tensor() of a
# Python float uses the default dtype (float32), which rounds r0 at the ~1e-8
# relative level and biases every output coordinate.
r0 = torch.tensor(180.0 / torch.pi, device=device, dtype=torch.float64)
# Native longitude; left at 0 at the reference point (r == 0), where
# atan2(0, 0) is undefined. Note the (-x, y) sign convention.
# NOTE(review): WCSLib's tanx2s uses atan2(x, -y); this form differs by 180 deg,
# which the Step 4 formulas appear to compensate for — confirm if either changes.
phi = torch.zeros_like(r)
non_zero_r = r != 0
if torch.any(non_zero_r):
    phi[non_zero_r] = torch.rad2deg(torch.atan2(-x_scaled[non_zero_r], y_scaled[non_zero_r]))
theta = torch.rad2deg(torch.atan2(r0, r))

# Step 4: Rotate native spherical (phi, theta) to celestial (RA, Dec)
# NOTE(review): LONPOLE is not read from the header; these formulas assume the
# default native pole longitude — confirm for headers that set LONPOLE.
phi_rad = torch.deg2rad(phi)
theta_rad = torch.deg2rad(theta)
ra0_rad = torch.tensor(CRVAL1 * torch.pi / 180.0, device=device, dtype=torch.float64)
dec0_rad = torch.tensor(CRVAL2 * torch.pi / 180.0, device=device, dtype=torch.float64)
sin_theta = torch.sin(theta_rad)
cos_theta = torch.cos(theta_rad)
sin_phi = torch.sin(phi_rad)
cos_phi = torch.cos(phi_rad)
sin_dec0 = torch.sin(dec0_rad)
cos_dec0 = torch.cos(dec0_rad)
# Declination. Clamp because rounding can push |sin_dec| marginally above 1,
# where arcsin returns NaN (e.g. near the celestial poles).
sin_dec = sin_theta * sin_dec0 + cos_theta * cos_dec0 * cos_phi
dec_rad = torch.arcsin(torch.clamp(sin_dec, -1.0, 1.0))
# RA offset from the reference RA
y_term = cos_theta * sin_phi
x_term = sin_theta * cos_dec0 - cos_theta * sin_dec0 * cos_phi
ra_rad = ra0_rad + torch.atan2(-y_term, x_term)
# Convert to degrees and wrap RA into [0, 360)
ra = torch.rad2deg(ra_rad) % 360.0
dec = torch.rad2deg(dec_rad)
Let’s compare our results with astropy now.
[6]:
# Print both results in the same format for a side-by-side comparison
for label, (ra_val, dec_val) in (("Torch", (ra, dec)), ("Astropy", (ra_astropy, dec_astropy))):
    print(f"Final celestial coordinates {label}: RA={ra_val}, Dec={dec_val}")
Final celestial coordinates Torch: RA=35.24380952587278, Dec=59.09821841758037
Final celestial coordinates Astropy: RA=35.24380952832857, Dec=59.09821844046949
The two results agree to within about 2e-8 degrees (less than 0.1 milliarcseconds), far below the arcsecond precision we need.
Now let’s check the world-to-pixel direction. This is a round-trip test: applying the inverse transformation to the RA/Dec computed above should recover the starting pixel coordinates.
[9]:
# Make sure RA/Dec are float64 tensors on the working device. torch.as_tensor
# returns the input unchanged when it already is a float64 tensor on `device`,
# and (unlike torch.tensor on a Python float) never falls back to float32.
# Both values are converted independently, so a tensor RA with a non-tensor
# Dec is handled too.
ra = torch.as_tensor(ra, device=device, dtype=torch.float64)
dec = torch.as_tensor(dec, device=device, dtype=torch.float64)
# Helper functions for trigonometric calculations
def atan2d(y, x):
    """Return the two-argument arctangent of ``y`` over ``x``, in degrees.

    PyTorch counterpart of WCSLib's ``atan2d``; the quadrant is resolved from
    the signs of both arguments.
    """
    angle_rad = torch.atan2(y, x)
    return torch.rad2deg(angle_rad)
def sincosd(angle_deg):
    """Return ``(sin, cos)`` of ``angle_deg``, an angle given in degrees.

    PyTorch counterpart of WCSLib's ``sincosd``.
    """
    as_radians = torch.deg2rad(angle_deg)
    sine = torch.sin(as_radians)
    cosine = torch.cos(as_radians)
    return sine, cosine
# Step 1: Convert from world (RA, Dec) to native spherical coordinates (phi, theta)
ra_rad = torch.deg2rad(ra)
dec_rad = torch.deg2rad(dec)
# Explicit float64 so the reference point is never rounded to the default dtype
ra0_rad = torch.tensor(CRVAL1 * torch.pi / 180.0, device=device, dtype=torch.float64)
dec0_rad = torch.tensor(CRVAL2 * torch.pi / 180.0, device=device, dtype=torch.float64)
# Difference in RA from the reference point
delta_ra = ra_rad - ra0_rad
sin_dec = torch.sin(dec_rad)
cos_dec = torch.cos(dec_rad)
sin_dec0 = torch.sin(dec0_rad)
cos_dec0 = torch.cos(dec0_rad)
sin_delta_ra = torch.sin(delta_ra)
cos_delta_ra = torch.cos(delta_ra)
# Native longitude (phi) from its numerator and denominator
y_phi = -cos_dec * sin_delta_ra  # Note the negative sign
x_phi = sin_dec * cos_dec0 - cos_dec * sin_dec0 * cos_delta_ra
phi = atan2d(y_phi, x_phi)
# Native latitude (theta). Clamp because rounding can push the argument
# marginally outside [-1, 1], where arcsin returns NaN.
sin_theta_arg = sin_dec * sin_dec0 + cos_dec * cos_dec0 * cos_delta_ra
theta = torch.rad2deg(torch.arcsin(torch.clamp(sin_theta_arg, -1.0, 1.0)))

# Step 2: Apply the TAN projection (tans2x in WCSLib)
sin_phi, cos_phi = sincosd(phi)
sin_theta, cos_theta = sincosd(theta)
# r = r0 * cos(theta) / sin(theta) diverges as sin(theta) -> 0
eps = 1e-10
if torch.any(torch.abs(sin_theta) < eps):
    raise ValueError("Singularity in tans2x: theta close to 0 degrees")
# r0 = 180/pi (R2D). Must be float64: torch.tensor() of a Python float uses the
# default dtype (float32), which rounds r0 at the ~1e-8 relative level.
r0 = torch.tensor(180.0 / torch.pi, device=device, dtype=torch.float64)
r = r0 * cos_theta / sin_theta
# Intermediate world coordinates (x_scaled, y_scaled)
x_scaled = -r * sin_phi  # Note the negative sign
y_scaled = r * cos_phi

# Step 3: Invert the linear transformation to get pixel offsets.
# FITS defines x_i = CDELT_i * sum_j PC_ij * p_j, i.e. CDELT scales the ROWS of
# PC (the same as astropy's WCS.pixel_scale_matrix). `PC_matrix * CDELT` would
# scale the columns and is only correct when CDELT1 == CDELT2.
CD_matrix = CDELT[:, np.newaxis] * PC_matrix
CD_matrix = torch.as_tensor(CD_matrix, device=device, dtype=torch.float64)
CD_inv = torch.linalg.inv(CD_matrix)
# (..., 2) @ (2, 2) -> (..., 2): one code path for scalar, 1-D and N-D inputs,
# with no .item() calls, so autograd is preserved.
standard_coords = torch.stack([x_scaled, y_scaled], dim=-1)
pixel_offsets = standard_coords @ CD_inv.T
u = pixel_offsets[..., 0]
v = pixel_offsets[..., 1]

# Step 4: Add the reference pixel back. CRPIX is 1-based (FITS) while the
# resulting pixel coordinates are 0-based, hence the "- 1".
x_pixel = u + (CRPIX1 - 1)
y_pixel = v + (CRPIX2 - 1)
print(f"Final: x={x_pixel}, y={y_pixel}")
print(f"Difference in x: {x_pixel - x_test}")
print(f"Difference in y: {y_pixel - y_test}")
Final: x=200.00000000004002, y=200.00000000003547
Difference in x: 4.001776687800884e-11
Difference in y: 3.54702933691442e-11
The round-trip error is about 4e-11 pixels, far below what we need :)
[ ]: