# FPFS shear estimator
# Copyright 20210905 Xiangchong Li.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# python lib
import jax
import logging
import numpy as np
import jax.numpy as jnp
from . import imgutil
from functools import partial
# Configure root logging at import time: INFO level, with every message
# prefixed by a "YYYY/MM/DD HH:MM:SS --- " timestamp.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y/%m/%d %H:%M:%S --- ",
    format="%(asctime)s %(message)s",
)
def results_coords(dd):
    """Packs peak coordinates into a record array.

    Args:
        dd (array_like): coordinates with one (y, x) row per source, i.e.
            shape (N, 2); numpy arrays, jax arrays and nested sequences are
            all accepted

    Returns:
        coords (np.recarray): record array with int32 fields ``fpfs_y`` and
            ``fpfs_x``
    """
    # np.asarray lets plain nested sequences (which have no ``.T``) through
    coords = np.rec.fromarrays(
        np.asarray(dd).T,
        dtype=[("fpfs_y", "i4"), ("fpfs_x", "i4")],
    )
    return coords
class measure_base:
    """A base class for measurement, which is extended to measure_source
    and measure_noise_cov.

    Args:
        psf_data (ndarray): an average PSF image used to initialize the task;
            must be a square 2D array
        pix_scale (float): pixel scale in arcsec
        sigma_arcsec (float): Shapelet kernel size in arcsec, in (0, 5]
        sigma_detect (float): detection kernel size in arcsec; defaults to
            ``sigma_arcsec`` and must be positive
        nnord (int): the highest order of Shapelets radial
            components [default: 4]

    Raises:
        ValueError: if a kernel size is out of range or ``psf_data`` is not a
            square 2D array
    """

    _DefaultName = "measure_base"

    def __init__(
        self,
        psf_data,
        pix_scale,
        sigma_arcsec,
        sigma_detect=None,
        nnord=4,
    ):
        if sigma_arcsec <= 0.0 or sigma_arcsec > 5.0:
            raise ValueError("sigma_arcsec should be positive and less than 5 arcsec")
        if sigma_detect is None:
            sigma_detect = sigma_arcsec
        elif sigma_detect <= 0.0:
            # would otherwise divide by zero or give a negative kernel below
            raise ValueError("sigma_detect should be positive")
        psf_shape = np.shape(psf_data)
        # ngrid is used for both Fourier axes, so the PSF grid must be square
        if len(psf_shape) != 2 or psf_shape[0] != psf_shape[1]:
            raise ValueError(
                "psf_data should be a square 2D array, but got shape %s"
                % (psf_shape,)
            )
        self.ngrid = psf_shape[0]
        self.nnord = nnord

        # Preparing PSF: its Fourier transform (origin moved to the grid
        # center by fftshift) and its Fourier power
        psf_data = jnp.array(psf_data, dtype="<f8")
        self.psf_fourier = jnp.fft.fftshift(jnp.fft.fft2(psf_data))
        self.psf_pow = imgutil.get_fourier_pow_fft(psf_data)

        # A few important scales
        self.pix_scale = pix_scale
        self._dk = 2.0 * jnp.pi / self.ngrid  # assuming pixel scale is 1
        # the following two assume pixel_scale = 1
        self.sigmaf = float(self.pix_scale / sigma_arcsec)
        self.sigmaf_det = float(self.pix_scale / sigma_detect)
        sigma_pixf = self.sigmaf / self._dk
        sigma_pixf_det = self.sigmaf_det / self._dk
        logging.info("Order of the shear estimator: nnord=%d", self.nnord)
        logging.info(
            "Shapelet kernel in configuration space: sigma= %.4f arcsec",
            sigma_arcsec,
        )
        logging.info(
            "Detection kernel in configuration space: sigma= %.4f arcsec",
            sigma_detect,
        )

        # effective Nyquist wave number (in pixel units); capped so that the
        # [ngrid//2 - klim_pix, ngrid//2 + klim_pix] window stays on the grid
        self.klim_pix = imgutil.get_klim(
            psf_array=self.psf_pow,
            sigma=(sigma_pixf + sigma_pixf_det) / 2.0 / jnp.sqrt(2.0),
            thres=1e-20,
        )
        self.klim_pix = min(self.klim_pix, self.ngrid // 2 - 1)
        self.klim = float(self.klim_pix * self._dk)
        logging.info("Maximum |k| is %.3f", self.klim)

        # Indices of the (2 * klim_pix + 1)^2 Fourier pixels centered on
        # [ngrid//2, ngrid//2] that are kept in all measurements
        self._indx = jnp.arange(
            self.ngrid // 2 - self.klim_pix,
            self.ngrid // 2 + self.klim_pix + 1,
        )
        self._indy = self._indx[:, None]
        self._ind2d = jnp.ix_(self._indx, self._indx)

    @partial(jax.jit, static_argnames=["self"])
    def deconvolve(self, data, prder=1.0, frder=1.0):
        """Deconvolves input data with the PSF or PSF power

        Args:
            data (ndarray):
                galaxy power or galaxy Fourier transform, origin is set to
                [ngrid//2, ngrid//2]
            prder (float):
                deconvolution order of the PSF Fourier power
            frder (float):
                deconvolution order of the PSF Fourier transform

        Returns:
            out (ndarray):
                deconvolved data (complex); only the pixels inside the klim
                window are filled, all other pixels are zero
        """
        out = jnp.zeros(data.shape, dtype="complex128")
        out2 = out.at[self._ind2d].set(
            data[self._ind2d]
            / self.psf_pow[self._ind2d] ** prder
            / self.psf_fourier[self._ind2d] ** frder
        )
        return out2
class measure_noise_cov(measure_base):
    """A class to measure the FPFS noise covariance of basis modes.

    Args:
        psf_data (ndarray): an average PSF image used to initialize the task
        pix_scale (float): pixel scale in arcsec
        sigma_arcsec (float): Shapelet kernel size
        sigma_detect (float): detection kernel size
        nnord (int): the highest order of Shapelets radial
            components [default: 4]
    """

    _DefaultName = "measure_noise_cov"

    def __init__(
        self,
        psf_data,
        pix_scale,
        sigma_arcsec,
        sigma_detect=None,
        nnord=4,
    ):
        super().__init__(
            psf_data=psf_data,
            sigma_arcsec=sigma_arcsec,
            nnord=nnord,
            pix_scale=pix_scale,
            sigma_detect=sigma_detect,
        )
        bfunc, bnames = imgutil.fpfs_bases(
            self.ngrid,
            nnord,
            self.sigmaf,
            self.sigmaf_det,
            self.klim,
        )
        # keep only the Fourier pixels inside the klim window
        self.bfunc = bfunc[:, self._indy, self._indx]
        self.bnames = bnames

    def measure(self, noise_pf):
        """Estimates the covariance matrix of the FPFS basis modes.

        Args:
            noise_pf (ndarray or float): power spectrum of the (assumed
                homogeneous) noise on the ngrid x ngrid Fourier grid, origin
                at [ngrid//2, ngrid//2]; a scalar is treated as a flat
                (white) spectrum with that value at every wave number

        Returns:
            cov_matrix (ndarray): covariance matrix of the FPFS basis modes;
                rows/columns follow the first axis of ``self.bfunc``
                (presumably labelled by ``self.bnames`` -- confirm in imgutil)
        """
        noise_pf = jnp.array(noise_pf, dtype="<f8")
        if noise_pf.ndim == 0:
            # white noise: expand the constant to the full Fourier grid
            noise_pf = jnp.full((self.ngrid, self.ngrid), noise_pf)
        # divide by the PSF power only (noise power is not divided by PSF FT)
        noise_pf_deconv = self.deconvolve(noise_pf, prder=1, frder=0)
        cov_matrix = (
            jnp.real(
                jnp.tensordot(
                    self.bfunc * noise_pf_deconv[None, self._indy, self._indx],
                    jnp.conjugate(self.bfunc),
                    axes=((1, 2), (1, 2)),
                )
            )
            # moments are divided by pix_scale**2, so covariances by **4
            / self.pix_scale**4.0
        )
        return cov_matrix
class measure_source(measure_base):
    """A class to measure FPFS shapelet and detection modes of sources.

    Args:
        psf_data (ndarray): an average PSF image used to initialize the task
        pix_scale (float): pixel scale in arcsec
        sigma_arcsec (float): Shapelet kernel size
        sigma_detect (float): detection kernel size
        nnord (int): the highest order of Shapelets radial components
            [default: 4]; only 4 and 6 are supported

    Raises:
        ValueError: if ``nnord`` is neither 4 nor 6
    """

    _DefaultName = "measure_source"

    def __init__(
        self,
        psf_data,
        pix_scale,
        sigma_arcsec,
        sigma_detect=None,
        nnord=4,
    ):
        super().__init__(
            psf_data=psf_data,
            sigma_arcsec=sigma_arcsec,
            nnord=nnord,
            pix_scale=pix_scale,
            sigma_detect=sigma_detect,
        )
        # Preparing shapelet basis: mode M_nm sits at flat index
        # nm = n * (nnord + 1) + m, where nnord is the maximum 'n' computed
        if nnord == 4:
            # Shear response only: M00, M20, M22 (real and imag), M40, M42
            self._indM = np.array([0, 10, 12, 20, 22])[:, None, None]
        elif nnord == 6:
            # Kappa and shear responses: M00, M20, M22 (real and imag),
            # M40, M42 (real and imag), M60
            self._indM = np.array([0, 14, 16, 28, 30, 42])[:, None, None]
        else:
            raise ValueError(
                "only support for nnord= 4 or nnord=6, but your input "
                "is nnord=%d" % nnord
            )
        # keep only the Fourier pixels inside the klim window
        chi = imgutil.shapelets2d(
            self.ngrid,
            nnord,
            self.sigmaf,
            self.klim,
        )[self._indM, self._indy, self._indx]
        psi = imgutil.detlets2d(
            self.ngrid,
            self.sigmaf_det,
            self.klim,
        )[:, :, self._indy, self._indx]
        self.prepare_chi(chi)
        self.prepare_psi(psi)

    def detect_sources(
        self,
        img_data,
        psf_data,
        thres,
        thres2,
        bound=None,
    ):
        """Returns the coordinates of detected sources

        Args:
            img_data (ndarray): observed image
            psf_data (ndarray): PSF image [must be well-centered], with the
                same shape as ``img_data``
            thres (float): detection threshold, must be positive
            thres2 (float): peak identification difference threshold, must
                be non-positive
            bound (int): remove sources at boundary
                [default: ngrid // 2 + 5]

        Returns:
            coords (ndarray): transposed output of ``imgutil.find_peaks``.
                NOTE(review): presumably one (y, x) row per detected peak,
                the layout ``measure`` expects for ``coords`` -- confirm.

        Raises:
            ValueError: if a threshold has the wrong type or sign
        """
        if not isinstance(thres, (int, float)):
            raise ValueError("thres must be float, but now got %s" % type(thres))
        if not isinstance(thres2, (int, float)):
            raise ValueError("thres2 must be float, but now got %s" % type(thres2))
        if not thres > 0.0:
            raise ValueError("detection threshold should be positive")
        if not thres2 <= 0.0:
            raise ValueError("difference threshold should be non-positive")
        psf_data = jnp.array(psf_data, dtype="<f8")
        assert img_data.shape == psf_data.shape, (
            "image and PSF should have the same shape. "
            "Please do padding before using this function."
        )
        # image smoothed with the shapelet kernel and with the detection
        # kernel respectively
        img_conv = imgutil.convolve2gausspsf(
            img_data,
            psf_data,
            self.sigmaf,
            self.klim,
        )
        img_conv_det = imgutil.convolve2gausspsf(
            img_data,
            psf_data,
            self.sigmaf_det,
            self.klim,
        )
        if bound is None:
            bound = self.ngrid // 2 + 5
        dd = imgutil.find_peaks(img_conv, img_conv_det, thres, thres2, bound).T
        return dd

    def prepare_chi(self, chi):
        """Prepares the basis to estimate shapelet modes

        Sets ``self.chi`` (stacked real-valued bases) and ``self.chi_types``
        (matching record field names).

        Args:
            chi (ndarray): 2d shapelet basis, one complex mode per row in
                the order selected by ``self._indM``
        """
        if self.nnord not in (4, 6):
            raise ValueError("only support for nnord=4 or nnord=6")
        # (basis, field name) pairs common to nnord=4 and nnord=6
        modes = [
            (chi.real[0], "fpfs_M00"),
            (chi.real[1], "fpfs_M20"),
            (chi.real[2], "fpfs_M22c"),
            (chi.imag[2], "fpfs_M22s"),
            (chi.real[3], "fpfs_M40"),
            (chi.real[4], "fpfs_M42c"),
            (chi.imag[4], "fpfs_M42s"),
        ]
        if self.nnord == 6:
            modes.append((chi.real[5], "fpfs_M60"))
        self.chi_types = [(name, "<f8") for _, name in modes]
        self.chi = jnp.stack([basis for basis, _ in modes])

    def prepare_psi(self, psi):
        """Prepares the basis to estimate detection modes

        Sets ``self.psi`` (stacked bases) and ``self.psi_types`` (matching
        record field names): first ``fpfs_v{i}`` from ``psi[i, 0]``, then
        ``fpfs_v{i}r{j}`` from ``psi[i, j]`` for j = 1, 2.

        Args:
            psi (ndarray): 2d detection basis indexed as [i, j, ky, kx]
        """
        bases = [psi[i, 0] for i in range(8)]
        names = ["fpfs_v%d" % i for i in range(8)]
        for j in (1, 2):
            for i in range(8):
                bases.append(psi[i, j])  # ps_i;j
                names.append("fpfs_v%dr%d" % (i, j))
        self.psi_types = [(name, "<f8") for name in names]
        out = jnp.stack(bases)
        assert len(out) == len(self.psi_types)
        self.psi = out

    def _project(self, data, basis):
        """Projects the klim window of a deconvolved Fourier image onto the
        given stack of basis vectors, returning one real value per basis.

        Pixel values are fluxes per pixel (e.g. nano Jy for HSC); dividing by
        pix_scale**2 converts them to flux per arcsec^2, with dk^2 in units of
        1 / arcsec^2. Correspondingly, covariances are divided by
        pix_scale**4.
        """
        return (
            jnp.sum(
                data[None, self._indy, self._indx] * basis,
                axis=(1, 2),
            ).real
            / self.pix_scale**2.0
        )

    @partial(jax.jit, static_argnames=["self"])
    def _itransform_chi(self, data):
        """Projects image onto the shapelet basis vectors

        Args:
            data (ndarray): deconvolved Fourier image
        Returns:
            out (ndarray): projection in shapelet space
        """
        return self._project(data, self.chi)

    @partial(jax.jit, static_argnames=["self"])
    def _itransform_psi(self, data):
        """Projects image onto the detection basis vectors

        Args:
            data (ndarray): deconvolved Fourier image
        Returns:
            out (ndarray): projection in detection space
        """
        return self._project(data, self.psi)

    def measure(self, exposure, coords=None):
        """Measures the FPFS moments at one or more positions of an image

        Args:
            exposure (ndarray): galaxy image
            coords (ndarray): peak coordinates -- a single (y, x) pair, an
                (N, 2) array with one (y, x) row per source, or a record
                array with ``fpfs_y`` / ``fpfs_x`` fields as produced by
                ``results_coords`` [default: image center]

        Returns:
            out (ndarray): FPFS moments, one row per coordinate (shapelet
                modes followed by detection modes)
        """
        image = jnp.asarray(exposure)
        if coords is None:
            coords = jnp.array(image.shape) // 2
        elif getattr(getattr(coords, "dtype", None), "names", None):
            # structured/record array: rebuild (N, 2) rows of (y, x)
            coords = np.stack([coords["fpfs_y"], coords["fpfs_x"]], axis=-1)
        # A single (y, x) pair becomes one row; (N, 2) input is unchanged.
        # (The previous ``atleast_2d(coords.T).T`` turned a pair into shape
        # (2, 1) and mapped over y and x separately.)
        coords = jnp.atleast_2d(jnp.asarray(coords))
        return jax.lax.map(lambda cc: self.measure_coord(cc, image), coords)

    @partial(jax.jit, static_argnames=["self"])
    def measure_coord(self, cc, image):
        """Measures the FPFS moments from a coordinate (jitted)

        Args:
            cc (ndarray): galaxy peak coordinate (y, x) in pixels
            image (ndarray): exposure

        Returns:
            mm (ndarray): FPFS moments of the ngrid x ngrid stamp centered
                at ``cc``
        """
        # Per the jax.lax.dynamic_slice docs, start indices are clamped so the
        # slice stays inside ``image``: stamps of sources closer than
        # ngrid // 2 to the edge are shifted, not centered (see ``bound`` in
        # detect_sources).
        stamp = jax.lax.dynamic_slice(
            image,
            (cc[0] - self.ngrid // 2, cc[1] - self.ngrid // 2),
            (self.ngrid, self.ngrid),
        )
        return self.measure_stamp(stamp)

    @partial(jax.jit, static_argnames=["self"])
    def measure_stamp(self, data):
        """Measures the FPFS moments from a stamp (jitted)

        Args:
            data (ndarray): galaxy image array of shape (ngrid, ngrid)

        Returns:
            mm (ndarray): FPFS moments (shapelet modes then detection modes)
        """
        gal_fourier = jnp.fft.fftshift(jnp.fft.fft2(data))
        # divide by the PSF Fourier transform only (no PSF-power term)
        gal_deconv = self.deconvolve(gal_fourier, prder=0.0, frder=1)
        mm = self._itransform_chi(gal_deconv)  # FPFS shapelets
        mp = self._itransform_psi(gal_deconv)  # FPFS detection
        return jnp.hstack([mm, mp])

    def get_results(self, out):
        """Converts measured moments into a record array

        Args:
            out (ndarray): output of ``measure``, one row per source with the
                shapelet modes followed by the detection modes

        Returns:
            res (np.recarray): record array with fields
                ``chi_types + psi_types``
        """
        tps = self.chi_types + self.psi_types
        res = np.rec.fromarrays(np.asarray(out).T, dtype=tps)
        return res