# Source code for fpfs.imgutil (image utilities for the FPFS shear estimator)

# FPFS shear estimator
# Copyright 20210905 Xiangchong Li.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# python lib

import jax
import math
import numpy as np
import jax.numpy as jnp
from functools import partial


@partial(jax.jit, static_argnames=["ny", "nx", "klim", "return_grid"])
def _gauss_kernel_fft(ny, nx, sigma, klim, return_grid=False):
    """Generates a Gaussian kernel on grids for np.fft.fft transform
    (we always shift k=0 to (ngird//2, ngird//2)). The kernel is truncated at
    radius klim.

    Args:
        ny (int):    		    grid size in y-direction
        nx (int):    		    grid size in x-direction
        sigma (float):		    scale of Gaussian in Fourier space
        klim (float):           upper limit of k
        return_grid (bool):     return grids [True] or not [Flase]
                                [default: False]
    Returns:
        out (ndarray):          Gaussian on grids
        xgrid,ygrid (typle):    grids for [y, x] axes if return_grid
    """
    # mask
    x = jnp.fft.fftshift(jnp.fft.fftfreq(nx, 1 / np.pi / 2.0))
    y = jnp.fft.fftshift(jnp.fft.fftfreq(ny, 1 / np.pi / 2.0))
    ygrid, xgrid = jnp.meshgrid(y, x, indexing="ij")
    r2 = xgrid**2.0 + ygrid**2.0
    mask = (r2 <= klim**2).astype(jnp.float64)
    out = jnp.exp(-r2 / 2.0 / sigma**2.0) * mask
    if not return_grid:
        return out
    else:
        return out, (ygrid, xgrid)


@partial(jax.jit, static_argnames=["ny", "nx", "klim", "return_grid"])
def _gauss_kernel_rfft(ny, nx, sigma, klim, return_grid=False):
    """Generates a Gaussian kernel on grids for np.fft.rfft transform
    The kernel is truncated at radius klim.

    Args:
        ny (int):    		    grid size in y-direction
        nx (int):    		    grid size in x-direction
        sigma (float):		    scale of Gaussian in Fourier space
        klim (float):           upper limit of k
        return_grid (bool):     return grids or not
    Returns:
        out (ndarray):          Gaussian on grids
        ygrid, xgrid (typle):   grids for [y, x] axes, if return_grid
    """
    x = jnp.fft.rfftfreq(nx, 1 / np.pi / 2.0)
    y = jnp.fft.fftfreq(ny, 1 / np.pi / 2.0)
    ygrid, xgrid = jnp.meshgrid(y, x, indexing="ij")
    r2 = xgrid**2.0 + ygrid**2.0
    mask = (r2 <= klim**2).astype(jnp.float64)
    out = jnp.exp(-r2 / 2.0 / sigma**2.0) * mask
    if not return_grid:
        return out
    else:
        return out, (ygrid, xgrid)


@jax.jit
def get_fourier_pow_fft(input_data):
    """Computes the Fourier power |FFT|^2 of an image, with k = 0 at the centre.

    Args:
        input_data (ndarray): image array, centroid does not matter.
    Returns:
        out (ndarray): Fourier power, fft-shifted so that k = 0 sits at
                       (ny // 2, nx // 2)
    """
    spectrum = jnp.fft.fft2(input_data)
    power = (jnp.abs(spectrum) ** 2.0).astype(jnp.float64)
    return jnp.fft.fftshift(power)
@jax.jit
def get_fourier_pow_rfft(input_data):
    """Computes the Fourier power |rfft2|^2 of an image on the unshifted rfft grid.

    Args:
        input_data (ndarray): image array. The centroid does not matter.
    Returns:
        galpow (ndarray): Fourier power with shape (ny, nx // 2 + 1)
    """
    galpow = jnp.abs(jnp.fft.rfft2(input_data)) ** 2.0
    return galpow.astype(jnp.float64)
def detlets2d(ngrid, sigma, klim):
    """Generates the detection kernels ("detlets") in Fourier space.

    chi00 are normalized to 1. Only square stamps (ny = nx = ngrid) are
    supported. For each of 8 unit directions n_i at angle i * pi / 4
    (i = 0..7), the phase factor exp(i k . n_i) is combined with the Gaussian
    G (divided by ngrid**2 for the inverse FFT):

        psi[i, 0] = G - G * exp(i k . n_i)
        psi[i, 1] = Q1 - (Q1 + ...) * exp(i k . n_i)    (q1 shear response)
        psi[i, 2] = Q2 - (Q2 + ...) * exp(i k . n_i)    (q2 shear response)

    Args:
        ngrid (int):    number of pixels in x and y direction
        sigma (float):  scale of shapelets in Fourier space
        klim (float):   upper limit of |k|
    Returns:
        psi (ndarray):  complex64 detlets basis in shape of [8, 3, ngrid, ngrid]
    """
    # Gaussian kernel together with its (k2, k1) = (k_y, k_x) grids
    gauss_ker, (k2grid, k1grid) = _gauss_kernel_fft(
        ngrid, ngrid, sigma, klim, return_grid=True
    )
    # normalisation for the inverse Fourier transform
    gauss_ker = gauss_ker / ngrid**2.0
    # kernels for the shear response
    q1_ker = (k1grid**2.0 - k2grid**2.0) / sigma**2.0 * gauss_ker
    q2_ker = (2.0 * k1grid * k2grid) / sigma**2.0 * gauss_ker
    # kernels for the neighbouring-pixel terms
    d1_ker = (-1j * k1grid) * gauss_ker
    d2_ker = (-1j * k2grid) * gauss_ker

    ny, nx = gauss_ker.shape
    psi = np.zeros((8, 3, ny, nx), dtype=np.complex64)
    for idir in range(8):
        cos_a = np.cos(np.pi / 4.0 * idir)
        sin_a = np.sin(np.pi / 4.0 * idir)
        phase = np.exp(1j * (k1grid * cos_a + k2grid * sin_a))
        psi[idir, 0] = gauss_ker - gauss_ker * phase
        psi[idir, 1] = q1_ker - (q1_ker + cos_a * d1_ker - sin_a * d2_ker) * phase
        psi[idir, 2] = q2_ker - (q2_ker + sin_a * d1_ker + cos_a * d2_ker) * phase
    return psi
def shapelets2d(ngrid, nord, sigma, klim):
    """Generates complex shapelet functions in Fourier space.

    chi00 are normalized to 1. Only square stamps (ny = nx = ngrid) are
    supported.

    Args:
        ngrid (int):    number of pixels in x and y direction
        nord (int):     highest radial order n of the shapelets (>= 0)
        sigma (float):  scale of shapelets in Fourier space
        klim (float):   upper limit of |k|
    Returns:
        chi (ndarray):  complex shapelet basis of shape
                        ((nord + 1) ** 2, ngrid, ngrid); entry
                        ``n * (nord + 1) + m`` holds chi_{n,m}. Only entries
                        with m = n, n - 2, ..., >= 0 are filled; all others
                        are zero.
    Raises:
        ValueError: if nord is negative
    """
    if nord < 0:
        raise ValueError("nord must be non-negative, but your input is nord=%d" % nord)
    mord = nord
    gaufunc, (yfunc, xfunc) = _gauss_kernel_fft(
        ngrid, ngrid, sigma, klim, return_grid=True
    )
    rfunc = np.sqrt(xfunc**2.0 + yfunc**2.0)  # |k|
    r2_over_sigma2 = (rfunc / sigma) ** 2.0
    ny, nx = gaufunc.shape

    # e^{i phi} = (k_x + i k_y) / |k|, left at 0 where |k| = 0
    rmask = rfunc != 0.0
    xtfunc = np.zeros((ny, nx), dtype=np.float64)
    ytfunc = np.zeros((ny, nx), dtype=np.float64)
    np.divide(xfunc, rfunc, where=rmask, out=xtfunc)  # cos(phi)
    np.divide(yfunc, rfunc, where=rmask, out=ytfunc)  # sin(phi)
    eulfunc = xtfunc + 1j * ytfunc

    # Generalized Laguerre polynomials L_n^m(r^2 / sigma^2) from the
    # three-term recurrence, seeded with L_0^m = 1 and L_1^m = 1 + m - x.
    lfunc = np.zeros((nord + 1, mord + 1, ny, nx), dtype=np.float64)
    lfunc[0] = 1.0
    if nord >= 1:
        # guarded: with nord == 0 there is no n = 1 slot to fill
        lfunc[1] = 1.0 - r2_over_sigma2 + np.arange(mord + 1)[:, None, None]
    for n in range(2, nord + 1):
        for m in range(mord + 1):
            lfunc[n, m] = (2.0 + (m - 1.0 - r2_over_sigma2) / n) * lfunc[
                n - 1, m
            ] - (1.0 + (m - 1.0) / n) * lfunc[n - 2, m]

    chi = np.zeros((nord + 1, mord + 1, ny, nx), dtype=np.complex64)
    for nn in range(nord + 1):
        # m runs over nn, nn - 2, ..., down to 0 or 1 (always non-negative)
        for mm in range(nn, -1, -2):
            c1 = (nn - mm) // 2
            d1 = (nn + mm) // 2
            cc = (math.factorial(c1) + 0.0) / (math.factorial(d1) + 0.0)
            chi[nn, mm] = (
                pow(-1.0, d1)
                * pow(cc, 0.5)
                * lfunc[c1, mm]
                * pow(r2_over_sigma2, mm / 2)
                * gaufunc
                * eulfunc**mm
                * (1j) ** nn
            )
    chi = chi.reshape(((nord + 1) ** 2, ny, nx)) / ngrid**2.0
    return chi
def shapelets2d_real(ngrid, nord, sigma, klim):
    """Generates the real shapelet functions used by FPFS in Fourier space.

    chi00 are normalized to 1. Only square stamps (ny = nx = ngrid) are
    supported.

    Args:
        ngrid (int):    number of pixels in x and y direction
        nord (int):     radial order of the shapelets; must be 4 or 6
        sigma (float):  scale of shapelets in Fourier space
        klim (float):   upper limit of |k|
    Returns:
        chi_2 (ndarray):    real shapelet basis w/ shape [n, ngrid, ngrid]
        name_s (list):      the n shapelet names; "mNMc" / "mNMs" are the
                            real / imaginary parts of M_NM (M > 0)
    Raises:
        ValueError: if nord is neither 4 nor 6
    """
    if nord == 4:
        # This setup is for shear response only:
        # M00, M20, M22 (real and imag), M40, M42 (real and imag)
        nm_pairs = [(0, 0), (2, 0), (2, 2), (4, 0), (4, 2)]
    elif nord == 6:
        # This setup is able to derive kappa response and shear response:
        # the nord == 4 set plus M60
        nm_pairs = [(0, 0), (2, 0), (2, 2), (4, 0), (4, 2), (6, 0)]
    else:
        raise ValueError(
            "only support for nord=4 or nord=6, but your input is nord=%d" % nord
        )

    # shapelets2d flattens (n, m) to the index n * (nord + 1) + m
    indm = np.array([n * (nord + 1) + m for n, m in nm_pairs])
    chi = shapelets2d(ngrid, nord, sigma, klim)[indm]

    # m = 0 shapelets are real; m > 0 ones give a real ("c") and an
    # imaginary ("s") component
    name_s = []
    parts = []
    for (n, m), func in zip(nm_pairs, chi):
        if m == 0:
            name_s.append("m%d%d" % (n, m))
            parts.append(func.real)
        else:
            name_s.extend(["m%d%dc" % (n, m), "m%d%ds" % (n, m)])
            parts.extend([func.real, func.imag])
    chi_2 = np.stack(parts).astype(np.float64)
    return chi_2, name_s
def fpfs_bases(ngrid, nord, sigma, sigma_det=None, klim=3.15):
    """Returns the FPFS bases: real shapelets followed by the detectlets.

    Args:
        ngrid (int):        stamp size
        nord (int):         the highest order of Shapelets radial components
        sigma (float):      shapelet kernel scale in Fourier space
        sigma_det (float):  detectlet kernel scale in Fourier space
                            [default: None, meaning use sigma]
        klim (float):       upper limit of |k| [default: 3.15]
    Returns:
        bfunc (ndarray):    stacked (complex) bases, shape [nbasis, ngrid, ngrid]
        bnames (list):      names of the bases: the shapelet names, then
                            "v0".."v7", "v0_g1".."v7_g1", "v0_g2".."v7_g2"
    """
    det_sigma = sigma if sigma_det is None else sigma_det
    shapelet_funcs, shapelet_names = shapelets2d_real(ngrid, nord, sigma, klim)
    psi = detlets2d(ngrid, det_sigma, klim)
    # psi is [8 directions, 3 components, ...]; swapping the first two axes
    # makes the 8 directions of each component contiguous after stacking
    detlet_funcs = np.vstack(np.swapaxes(psi, 0, 1))
    detlet_names = [
        "v%d%s" % (idir, suffix)
        for suffix in ("", "_g1", "_g2")
        for idir in range(8)
    ]
    bfunc = np.vstack([shapelet_funcs, detlet_funcs])
    return bfunc, shapelet_names + detlet_names
def fit_noise_pf(ngrid, gal_pow, noise_mod, rlim):
    """Fits the noise power from the outskirts of an observed galaxy power.

    A central square of half-width ``int(max(0.4 * ngrid, rlim))`` around
    pixel (ngrid // 2, ngrid // 2), clipped to the stamp, is excluded. The
    remaining pixels of gal_pow are fitted by least squares with a linear
    combination of the noise models.

    Args:
        ngrid (int):            number of pixels in x and y direction
        gal_pow (ndarray):      galaxy Fourier power function, (ngrid, ngrid)
        noise_mod (ndarray):    noise power models, (nmod, ngrid, ngrid)
        rlim (float):           minimum half-width of the excluded square
    Returns:
        out (ndarray):          noise power to be subtracted, (ngrid, ngrid)
    Raises:
        ValueError: if the excluded square covers the whole stamp, leaving
                    no pixels to fit
    """
    rlim2 = int(max(ngrid * 0.4, rlim))
    center = ngrid // 2
    # clip so a large rlim cannot wrap around or index past the stamp
    lo = max(center - rlim2, 0)
    hi = min(center + rlim2 + 1, ngrid)
    mask = np.ones((ngrid, ngrid), dtype=bool)
    mask[lo:hi, lo:hi] = False
    if not mask.any():
        raise ValueError(
            "rlim=%s excludes the whole %dx%d stamp; no pixels left to fit"
            % (rlim, ngrid, ngrid)
        )
    vl = gal_pow[mask]
    nl = noise_mod[:, mask]
    par = np.linalg.lstsq(nl.T, vl, rcond=None)[0]
    out = np.sum(par[:, None, None] * noise_mod, axis=0)
    return out
def pcaimages(xdata, nmodes):
    """Estimates the principal components of a stack of images.

    No mean is subtracted (uncentred PCA). The eigen-decomposition is done on
    the nobj x nobj Gram matrix of the flattened images.

    Args:
        xdata (ndarray):    input images, shape (nobj, ny, nx)
        nmodes (int):       number of pcs to keep
    Returns:
        out (ndarray):      pc images, shape (nmodes, ny, nx), ordered by
                            decreasing eigenvalue
        stds (ndarray):     square roots of the leading (up to nmodes + 10)
                            eigenvalues, decreasing
        coeffs (ndarray):   projection coefficients, shape (nmodes, nobj);
                            ``out[k] == coeffs[k] @ xdata.reshape(nobj, -1)``
                            and row k is matched to ``out[k]``
    Raises:
        ValueError: if xdata is not three-dimensional
    """
    xdata = np.asarray(xdata)
    if xdata.ndim != 3:
        raise ValueError("xdata must be a 3D array with shape (nobj, ny, nx)")
    nobj, nn2, nn1 = xdata.shape
    # each row is one flattened image
    flat = xdata.reshape((nobj, nn1 * nn2))
    data_mat = np.dot(flat, flat.T) / (nobj - 1)
    # eigh returns ascending eigenvalues; flip values AND vectors together so
    # that mode 0 is the leading one for pcs, stds and coeffs alike
    eig_val, eig_vec = np.linalg.eigh(data_mat)
    eig_val = eig_val[::-1]
    eig_vec = eig_vec[:, ::-1]
    coeffs = eig_vec.T[:nmodes]
    pcs = np.dot(coeffs, flat)
    out = pcs.reshape((nmodes, nn2, nn1))
    stds = np.sqrt(eig_val[: nmodes + 10])
    return out, stds, coeffs
def cut_img(img, rcut):
    """Cuts img into a postage stamp of width 2 * rcut around its centre.

    Along each axis of length n the stamp covers indices
    [n // 2 - rcut, n // 2 + rcut), so non-square images are centred on both
    axes.

    Args:
        img (ndarray):  input image
        rcut (int):     cutout radius
    Returns:
        out (ndarray):  image in a stamp, shape (2 * rcut, 2 * rcut)
    Raises:
        ValueError: if rcut is negative or the stamp does not fit in img
    """
    ny, nx = img.shape[0], img.shape[1]
    # a negative start index would silently wrap around to the far edge
    if rcut < 0 or rcut > ny // 2 or rcut > nx // 2:
        raise ValueError(
            "rcut=%d does not fit in an image of shape %s" % (rcut, img.shape)
        )
    beg_y = ny // 2 - rcut
    beg_x = nx // 2 - rcut
    out = img[beg_y : beg_y + 2 * rcut, beg_x : beg_x + 2 * rcut]
    return out
@jax.jit
def get_pixel_detect_mask(sel, img, thres2):
    """Keeps only pixels that exceed each of their 4 nearest neighbours.

    A pixel survives when ``sel`` is True there and ``img`` minus each of its
    left/right/up/down neighbours is strictly greater than ``thres2``.
    Neighbours are obtained with ``jnp.roll``, so pixels on an image edge are
    compared with the pixels on the opposite edge.

    Args:
        sel (ndarray):      initial boolean selection (broadcastable with img)
        img (ndarray):      image used for the neighbour comparison
        thres2 (float):     minimum difference to every neighbour
    Returns:
        sel (ndarray):      boolean selection after the local-maximum cut
    """
    shifted = [
        jnp.roll(img, shift=step, axis=axis)
        for axis in (-1, -2)
        for step in (-1, 1)
    ]
    for neighbour in shifted:
        sel = jnp.logical_and(sel, (img - neighbour) > thres2)
    return sel
def find_peaks(img_conv, img_conv_det, thres, thres2=0.0, bound=20.0):
    """Detects peaks and returns the coordinates (y, x).

    This function does the pre-selection in Li & Mandelbaum (2023). A pixel
    is kept when all of the following hold:
      * img_conv > thres there;
      * img_conv_det exceeds each of its 4 nearest neighbours by more than
        thres2 (via get_pixel_detect_mask);
      * bound < y < ny - bound and bound < x < nx - bound.

    Args:
        img_conv (ndarray):     convolved image used for the threshold cut
        img_conv_det (ndarray): convolved image used for the peak test
        thres (float):          detection threshold
        thres2 (float):         peak identification difference threshold
        bound (float):          minimum distance to the image boundary
    Returns:
        coord_array (ndarray):  integer array of shape (2, npeak); row 0 is y
                                and row 1 is x
    """
    candidates = get_pixel_detect_mask(img_conv > thres, img_conv_det, thres2)
    coord_array = jnp.int_(jnp.stack(jnp.where(candidates)))
    ny, nx = img_conv.shape
    ys, xs = coord_array[0], coord_array[1]
    in_y = (ys > bound) & (ys < ny - bound)
    in_x = (xs > bound) & (xs < nx - bound)
    return coord_array[:, in_y & in_x]
@jax.jit
def convolve2gausspsf(img_data, psf_data, sigmaf, klim):
    """Reconvolves an image so that its PSF becomes a truncated Gaussian.

    In Fourier space the image is divided by the PSF and multiplied by a
    Gaussian of scale ``sigmaf``. Modes with |k| > klim are set to zero.

    Args:
        img_data (ndarray): image data, same shape as psf_data
        psf_data (ndarray): psf data, centred on pixel (ny // 2, nx // 2)
        sigmaf (float):     sigma of Gaussian in Fourier space
        klim (float):       radius for masking in Fourier space; may be a
                            Python float or a (traced) JAX scalar
    Returns:
        img_conv (ndarray): the reconvolved image, shape (ny, nx)
    """
    ny, nx = psf_data.shape
    # move the PSF centre to pixel (0, 0) before transforming
    psf_fourier = jnp.fft.rfft2(jnp.fft.ifftshift(psf_data))
    # Gaussian kernel on the rfft grid. It is built here rather than through
    # _gauss_kernel_rfft, whose klim is a static (hashed) argument and cannot
    # receive the traced klim of this jitted function.
    kx = jnp.fft.rfftfreq(nx, 1 / np.pi / 2.0)
    ky = jnp.fft.fftfreq(ny, 1 / np.pi / 2.0)
    ygrid, xgrid = jnp.meshgrid(ky, kx, indexing="ij")
    r2 = xgrid**2.0 + ygrid**2.0
    inside = r2 <= klim**2
    gauss_kernel = jnp.exp(-r2 / 2.0 / sigmaf**2.0)
    # select rather than multiply by a 0/1 mask: a PSF transform that
    # vanishes outside klim would otherwise give inf * 0 = NaN there
    img_fourier = jnp.where(
        inside, jnp.fft.rfft2(img_data) / psf_fourier * gauss_kernel, 0.0
    )
    img_conv = jnp.fft.irfft2(img_fourier, (ny, nx))
    return img_conv
@jax.jit
def get_klim(psf_array, sigma, thres=1e-20):
    """Gets klim, the radius beyond which the Gaussian kernel is negligible.

    The region outside klim is supressed by the shapelet Gaussian kernel in
    the FPFS shear estimation method, so values there can be set to zero.

    Starting at ``ngrid // 5`` pixels, the search steps outwards from the
    central pixel (ngrid // 2, ngrid // 2) along +y and +x. It stops at the
    first distance d where the smaller of
    |exp(-d^2 / 2 / sigma^2) / psf_array[...]| along the two axes is no
    longer > thres. It also stops at the last in-bounds distance,
    ``ngrid - 1 - ngrid // 2``.

    Args:
        psf_array (ndarray):    PSF's Fourier power or Fourier transform
                                (square, centred at ngrid // 2)
        sigma (float):          one sigma of Gaussian Fourier power
        thres (float):          the threshold for a truncation on Gaussian
                                [default: 1e-20]
    Returns:
        klim (int array):       the limit radius in pixels
    """
    ngrid = psf_array.shape[0]
    center = ngrid // 2
    # largest distance that still indexes inside psf_array; beyond it JAX
    # would silently clamp the index to the edge pixel
    dmax = ngrid - 1 - center

    def ratio_above_thres(dist):
        gauss = jnp.exp(-(dist**2.0) / 2.0 / sigma**2.0)
        v1 = abs(gauss / psf_array[center + dist, center])
        v2 = abs(gauss / psf_array[center, center + dist])
        # pick v1 if v1 < v2 else v2 (the smaller ratio)
        return jnp.where(v1 < v2, v1, v2) > thres

    def cond_fun(dist):
        return jnp.logical_and(dist < dmax, ratio_above_thres(dist))

    def body_fun(dist):
        return dist + 1

    klim = jax.lax.while_loop(
        cond_fun=cond_fun,
        body_fun=body_fun,
        init_val=ngrid // 5,
    )
    return klim
def truncate_square(arr, rcut):
    """Zeroes, in place, everything outside the central square of half-width rcut.

    Rows and columns with index < ngrid // 2 - rcut or >= ngrid // 2 + rcut
    are set to 0. If rcut >= ngrid // 2, nothing on the low side is zeroed.
    A negative rcut (empty square) zeroes the whole array.

    Args:
        arr (ndarray):  2D square array, modified in place
        rcut (int):     half-width of the square to keep
    Returns:
        None
    Raises:
        ValueError: if arr is not a 2D square array
    """
    if len(arr.shape) != 2 or arr.shape[0] != arr.shape[1]:
        raise ValueError("Input array must be a 2D square array")
    ngrid = arr.shape[0]
    # clip at 0: a negative slice bound would count from the end and wrongly
    # zero (almost) the whole array
    lo = max(ngrid // 2 - rcut, 0)
    hi = max(ngrid // 2 + rcut, 0)
    arr[:lo, :] = 0
    arr[hi:, :] = 0
    arr[:, :lo] = 0
    arr[:, hi:] = 0
    return
def truncate_circle(arr, rcut):
    """Zeroes, in place, pixels farther than rcut from the central pixel.

    Distances are measured from pixel (ngrid // 2, ngrid // 2). Pixels whose
    squared distance exceeds rcut**2 are set to 0.0.

    Args:
        arr (ndarray):  2D square array, modified in place
        rcut (float):   radius of the disc to keep
    Returns:
        None
    Raises:
        ValueError: if arr is not a 2D square array
    """
    shape = arr.shape
    is_square_2d = len(shape) == 2 and shape[0] == shape[1]
    if not is_square_2d:
        raise ValueError("Input array must be a 2D square array")
    ngrid = shape[0]
    mid = ngrid // 2
    rows, cols = np.ogrid[0:ngrid, 0:ngrid]
    outside = (rows - mid) ** 2 + (cols - mid) ** 2 > rcut**2
    arr[outside] = 0.0
    return