{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "66efec16355f2b58",
   "metadata": {},
   "source": [
    "# Comparison of Astropy vs Torch implementation of coordinate transformation\n",
    "In this notebook, we compare astropy's core WCS pixel-to-world transformation (`WCS.wcs_pix2world`) with our Torch implementation, and then run the inverse (world-to-pixel) transformation as a round-trip check. For our implementation, we've broken down what we do in `reprojection.reproject.calculate_skyCoords()` into explicit steps for readability.\n",
    "\n",
    "Assumptions made by the step-by-step code below:\n",
    "- the image uses a gnomonic (TAN) projection;\n",
    "- the native longitude of the celestial pole has its default value (LONPOLE = 180°), since the code never reads it from the header;\n",
    "- no distortion terms are applied, which matches `wcs_pix2world` (core WCS transformation only)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96c2997c284d8c07",
   "metadata": {},
   "outputs": [],
   "source": [
    "from astropy.io import fits\n",
    "from astropy.wcs import WCS\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69ae1057cf95927d",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cpu'\n",
    "FITS_PATH = './data/Atik1442426-0035_0032_light.fits'\n",
    "\n",
    "# Load the target image and its WCS. The context manager closes the file once\n",
    "# the header and pixel data have been copied into memory.\n",
    "with fits.open(FITS_PATH) as hdul:\n",
    "    target_hdu = hdul[0]\n",
    "    target_header = target_hdu.header.copy()\n",
    "    # Convert the data to native byte order before creating tensors\n",
    "    # (astype returns a copy, so nothing references the file afterwards)\n",
    "    target_data = target_hdu.data.astype(target_hdu.data.dtype.newbyteorder('='))\n",
    "\n",
    "target_wcs = WCS(target_header)\n",
    "target_shape = target_data.shape\n",
    "# Now create the tensor\n",
    "target_image = torch.tensor(target_data, dtype=torch.float64, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9aa0d594e436e44",
   "metadata": {},
   "source": "Now let's compute the reference result with astropy."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8918df65835b4412",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the WCS parsed above as the astropy reference\n",
    "target_wcs_astropy = target_wcs\n",
    "# Test pixel coordinate (0-based, matching the origin=0 argument below)\n",
    "x_test, y_test = 200, 200\n",
    "\n",
    "# Convert to world coordinates (RA, Dec)\n",
    "ra_astropy, dec_astropy = target_wcs_astropy.wcs_pix2world(x_test, y_test, 0)\n",
    "# Astropy's own round trip back to pixel coordinates, as a baseline\n",
    "x_roundtrip_astropy, y_roundtrip_astropy = target_wcs_astropy.wcs_world2pix(ra_astropy, dec_astropy, 0)\n",
    "\n",
    "print(ra_astropy, dec_astropy)\n",
    "print(x_roundtrip_astropy, y_roundtrip_astropy)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5f473af97e5f2cd",
   "metadata": {},
   "source": "And now using our implementation."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acb63ed3f25dd55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get WCS parameters\n",
    "CRPIX1 = target_wcs_astropy.wcs.crpix[0]\n",
    "CRPIX2 = target_wcs_astropy.wcs.crpix[1]\n",
    "CRVAL1 = target_wcs_astropy.wcs.crval[0]  # Reference RA\n",
    "CRVAL2 = target_wcs_astropy.wcs.crval[1]  # Reference Dec\n",
    "PC_matrix = target_wcs_astropy.wcs.get_pc()  # PC Matrix\n",
    "CDELT = np.array(target_wcs_astropy.wcs.cdelt)  # Scaling factors\n",
    "\n",
    "# Accept Python scalars, numpy arrays or tensors alike\n",
    "x = torch.as_tensor(x_test, device=device, dtype=torch.float64)\n",
    "y = torch.as_tensor(y_test, device=device, dtype=torch.float64)\n",
    "\n",
    "# Step 1: Compute pixel offsets from the reference pixel.\n",
    "# CRPIX is 1-based (FITS convention) while x/y are 0-based, hence the -1.\n",
    "u = x - (CRPIX1 - 1)\n",
    "v = y - (CRPIX2 - 1)\n",
    "\n",
    "# Step 2: Apply PC Matrix (Rotation) and CDELT (Scaling)\n",
    "# CD = diag(CDELT) @ PC, i.e. CD[i, j] = CDELT[i] * PC[i, j]: CDELT scales the\n",
    "# ROWS of PC. A plain `PC_matrix * CDELT` would broadcast over the columns,\n",
    "# which is only correct when PC is diagonal or both CDELT values are equal.\n",
    "CD_matrix = torch.tensor(CDELT[:, np.newaxis] * PC_matrix, device=device, dtype=torch.float64)\n",
    "# Stacking on the last axis handles scalars and arrays of any shape alike\n",
    "pixel_offsets = torch.stack([u, v], dim=-1)  # Shape: [..., 2]\n",
    "transformed = torch.matmul(pixel_offsets, CD_matrix.T)  # Shape: [..., 2]\n",
    "x_scaled, y_scaled = transformed.unbind(dim=-1)\n",
    "\n",
    "# Step 3: Deproject from the tangent plane to native spherical coordinates\n",
    "# (TAN projection, cf. tanx2s in WCSLib)\n",
    "# Compute the radial distance\n",
    "r = torch.sqrt(x_scaled ** 2 + y_scaled ** 2)\n",
    "# R2D from WCSLib. The dtype must be float64: the float32 default keeps only\n",
    "# ~7 significant digits and would limit the accuracy of theta.\n",
    "r0 = torch.tensor(180.0 / torch.pi, device=device, dtype=torch.float64)\n",
    "\n",
    "# phi is undefined at the reference point (r == 0); use 0 there.\n",
    "# Note the sign convention: with atan2(-x, y), phi is measured relative to\n",
    "# the native longitude of the celestial pole (assumed LONPOLE = 180 deg),\n",
    "# which is what the rotation in Step 4 expects.\n",
    "phi = torch.where(\n",
    "    r != 0,\n",
    "    torch.rad2deg(torch.atan2(-x_scaled, y_scaled)),\n",
    "    torch.zeros_like(r),\n",
    ")\n",
    "theta = torch.rad2deg(torch.atan2(r0, r))\n",
    "\n",
    "# Step 4: Rotate native spherical (phi, theta) to celestial (RA, Dec)\n",
    "# (cf. sphx2s in WCSLib)\n",
    "phi_rad = torch.deg2rad(phi)\n",
    "theta_rad = torch.deg2rad(theta)\n",
    "ra0_rad = torch.tensor(CRVAL1 * torch.pi / 180.0, device=device, dtype=torch.float64)\n",
    "dec0_rad = torch.tensor(CRVAL2 * torch.pi / 180.0, device=device, dtype=torch.float64)\n",
    "\n",
    "# For TAN projection, the pole is at (0,90) in native coordinates\n",
    "sin_theta = torch.sin(theta_rad)\n",
    "cos_theta = torch.cos(theta_rad)\n",
    "sin_phi = torch.sin(phi_rad)\n",
    "cos_phi = torch.cos(phi_rad)\n",
    "sin_dec0 = torch.sin(dec0_rad)\n",
    "cos_dec0 = torch.cos(dec0_rad)\n",
    "\n",
    "# Declination. The clamp guards arcsin against round-off pushing the\n",
    "# argument marginally outside [-1, 1], which would produce NaN.\n",
    "sin_dec = sin_theta * sin_dec0 + cos_theta * cos_dec0 * cos_phi\n",
    "dec_rad = torch.arcsin(torch.clamp(sin_dec, -1.0, 1.0))\n",
    "\n",
    "# Calculate RA offset from the reference RA\n",
    "y_term = cos_theta * sin_phi\n",
    "x_term = sin_theta * cos_dec0 - cos_theta * sin_dec0 * cos_phi\n",
    "ra_rad = ra0_rad + torch.atan2(-y_term, x_term)\n",
    "\n",
    "# Convert to degrees and wrap RA into [0, 360)\n",
    "ra = torch.rad2deg(ra_rad) % 360.0\n",
    "dec = torch.rad2deg(dec_rad)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "faaefb30c184bdee",
   "metadata": {},
   "source": "Let's compare our results with astropy now."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4c943e04b06eec1",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Final celestial coordinates Torch: RA={ra}, Dec={dec}\")\n",
    "print(f\"Final celestial coordinates Astropy: RA={ra_astropy}, Dec={dec_astropy}\")\n",
    "\n",
    "# Both implementations must agree to better than one milliarcsecond\n",
    "MAS_IN_DEG = 1e-3 / 3600.0\n",
    "np.testing.assert_allclose(ra.cpu().numpy(), ra_astropy, rtol=0, atol=MAS_IN_DEG)\n",
    "np.testing.assert_allclose(dec.cpu().numpy(), dec_astropy, rtol=0, atol=MAS_IN_DEG)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d893478cac3c32e",
   "metadata": {},
   "source": [
    "The two results agree to better than a milliarcsecond (enforced by the assertions above), which is far tighter than the arcsecond precision we need.\n",
    "\n",
    "Now let's check the world-to-pixel direction. This is a round-trip test: feeding the (RA, Dec) computed above through the inverse transformation should return the pixel coordinates we started from."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "358977445fd6550b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper functions for trigonometric calculations\n",
    "def atan2d(y, x):\n",
    "    \"\"\"Return atan2(y, x) in degrees (Torch counterpart of WCSLib's atan2d).\"\"\"\n",
    "    return torch.rad2deg(torch.atan2(y, x))\n",
    "\n",
    "def sincosd(angle_deg):\n",
    "    \"\"\"Return (sin, cos) of an angle given in degrees (cf. WCSLib's sincosd).\"\"\"\n",
    "    angle_rad = torch.deg2rad(angle_deg)\n",
    "    return torch.sin(angle_rad), torch.cos(angle_rad)\n",
    "\n",
    "# Accept Python scalars, numpy arrays or tensors alike, without overwriting\n",
    "# the `ra`/`dec` produced by the forward transformation\n",
    "ra_in = torch.as_tensor(ra, device=device, dtype=torch.float64)\n",
    "dec_in = torch.as_tensor(dec, device=device, dtype=torch.float64)\n",
    "\n",
    "# Step 1: Convert from world to native spherical coordinates\n",
    "# Convert to radians\n",
    "ra_in_rad = torch.deg2rad(ra_in)\n",
    "dec_in_rad = torch.deg2rad(dec_in)\n",
    "ra0_rad = torch.tensor(CRVAL1 * torch.pi / 180.0, device=device, dtype=torch.float64)\n",
    "dec0_rad = torch.tensor(CRVAL2 * torch.pi / 180.0, device=device, dtype=torch.float64)\n",
    "\n",
    "# Calculate the difference in RA\n",
    "delta_ra = ra_in_rad - ra0_rad\n",
    "\n",
    "# Calculate sine and cosine values\n",
    "sin_dec_in = torch.sin(dec_in_rad)\n",
    "cos_dec_in = torch.cos(dec_in_rad)\n",
    "sin_dec0 = torch.sin(dec0_rad)\n",
    "cos_dec0 = torch.cos(dec0_rad)\n",
    "sin_delta_ra = torch.sin(delta_ra)\n",
    "cos_delta_ra = torch.cos(delta_ra)\n",
    "\n",
    "# Native longitude (phi), using the same sign convention as the forward\n",
    "# transformation (note the negative sign in the numerator)\n",
    "y_phi = -cos_dec_in * sin_delta_ra\n",
    "x_phi = sin_dec_in * cos_dec0 - cos_dec_in * sin_dec0 * cos_delta_ra\n",
    "native_phi = atan2d(y_phi, x_phi)\n",
    "\n",
    "# Native latitude (theta). The clamp guards arcsin against round-off pushing\n",
    "# the argument marginally above 1 at the reference point, which would give NaN.\n",
    "sin_native_theta = sin_dec_in * sin_dec0 + cos_dec_in * cos_dec0 * cos_delta_ra\n",
    "native_theta = torch.rad2deg(torch.arcsin(torch.clamp(sin_native_theta, -1.0, 1.0)))\n",
    "\n",
    "# Step 2: Apply the TAN projection (cf. tans2x in WCSLib)\n",
    "# Calculate sine and cosine of phi and theta\n",
    "sin_phi_n, cos_phi_n = sincosd(native_phi)\n",
    "sin_theta_n, cos_theta_n = sincosd(native_theta)\n",
    "\n",
    "# Check for singularity (when sin_theta is zero)\n",
    "eps = 1e-10\n",
    "if torch.any(torch.abs(sin_theta_n) < eps):\n",
    "    raise ValueError(\"Singularity in tans2x: theta close to 0 degrees\")\n",
    "\n",
    "# r0 is the radius scaling factor (180/pi); float64 to match the forward step\n",
    "r0 = torch.tensor(180.0 / torch.pi, device=device, dtype=torch.float64)\n",
    "\n",
    "# Radial distance in the projection plane\n",
    "r_proj = r0 * cos_theta_n / sin_theta_n\n",
    "\n",
    "# Intermediate world coordinates; the signs mirror the atan2(-x, y)\n",
    "# convention used for phi in the forward transformation\n",
    "x_inter = -r_proj * sin_phi_n\n",
    "y_inter = r_proj * cos_phi_n\n",
    "\n",
    "# Step 3: Apply the inverse of the CD matrix to get pixel offsets\n",
    "# (CD_matrix is the float64 tensor built in the forward cell)\n",
    "CD_inv = torch.linalg.inv(CD_matrix)\n",
    "# Stacking on the last axis handles scalars and arrays of any shape alike\n",
    "standard_coords = torch.stack([x_inter, y_inter], dim=-1)  # Shape: [..., 2]\n",
    "offsets_back = torch.matmul(standard_coords, CD_inv.T)  # Shape: [..., 2]\n",
    "u_back, v_back = offsets_back.unbind(dim=-1)\n",
    "\n",
    "# Step 4: Add the reference pixel to get final pixel coordinates\n",
    "# Remember to add (CRPIX-1) to account for 1-based indexing in FITS/WCS\n",
    "x_pixel = u_back + (CRPIX1 - 1)\n",
    "y_pixel = v_back + (CRPIX2 - 1)\n",
    "\n",
    "print(f\"Final: x={x_pixel}, y={y_pixel}\")\n",
    "print(f\"Difference in x: {x_pixel - x_test}\")\n",
    "print(f\"Difference in y: {y_pixel - y_test}\")\n",
    "\n",
    "# The round trip must land back on the starting pixel\n",
    "ROUNDTRIP_TOL_PIX = 1e-6\n",
    "np.testing.assert_allclose(x_pixel.cpu().numpy(), x_test, rtol=0, atol=ROUNDTRIP_TOL_PIX)\n",
    "np.testing.assert_allclose(y_pixel.cpu().numpy(), y_test, rtol=0, atol=ROUNDTRIP_TOL_PIX)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b96ec6d3156b5e42",
   "metadata": {},
   "source": "The round trip returns to the starting pixel to within a tiny fraction of a pixel (enforced by the assertions above), so this is well below what we need :)"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}