Source code for dfreproject.tensorhdu

from astropy.io.fits import PrimaryHDU
import torch
import numpy as np



[docs] class TensorHDU(PrimaryHDU): def __init__( self, data=None, header=None, do_not_scale_image_data=False, ignore_blank=False, uint=True, scale_back=None, ): """ Construct a pytorch tensor HDU (Child class of PrimaryHDU with added tensor property). Parameters ---------- data : Pytorch tensor, array or ``astropy.io.fits.hdu.base.DELAYED``, optional The data in the HDU. header : `~astropy.io.fits.Header`, optional The header to be used (as a template). If ``header`` is `None`, a minimal header will be provided. do_not_scale_image_data : bool, optional If `True`, image data is not scaled using BSCALE/BZERO values when read. (default: False) ignore_blank : bool, optional If `True`, the BLANK header keyword will be ignored if present. Otherwise, pixels equal to this value will be replaced with NaNs. (default: False) uint : bool, optional Interpret signed integer data where ``BZERO`` is the central value and ``BSCALE == 1`` as unsigned integer data. For example, ``int16`` data with ``BZERO = 32768`` and ``BSCALE = 1`` would be treated as ``uint16`` data. (default: True) scale_back : bool, optional If `True`, when saving changes to a file that contained scaled image data, restore the data to the original type and reapply the original BSCALE/BZERO values. This could lead to loss of accuracy if scaling back to integer values after performing floating point operations on the data. Pseudo-unsigned integers are automatically rescaled unless scale_back is explicitly set to `False`. (default: None) """ self.tensor = data if data is not None and isinstance(data, torch.Tensor): data = data.detach().cpu().numpy() super().__init__( data=data, header=header, do_not_scale_image_data=do_not_scale_image_data, uint=uint, ignore_blank=ignore_blank, scale_back=scale_back, ) @property
[docs] def tensor(self) -> torch.Tensor: """ Returns the image data as a torch.Tensor. """ return self.__dict__.get("tensor", None)
@tensor.setter def tensor(self, data): # Accept torch.Tensor or numpy array, but always store as torch.Tensor if data is not None and not isinstance(data, torch.Tensor): try: if isinstance(data, np.ndarray): data = torch.tensor(data, requires_grad = True) print("Converted numpy array to torch tensor with requires_grad=True.") else: data = torch.tensor(data, requires_grad = True) except Exception: data = torch.tensor(data, requires_grad = True) elif isinstance(data, torch.Tensor): if data.dtype != torch.float64: data = data.to(torch.float64) self.__dict__["tensor"] = data