Source code for blip.src.faster_geometry.interface

"""Interface with the rest of BLIP."""

import functools
import logging
from collections import namedtuple

import jax
from jax import numpy as jnp, vmap, jit
import chex
import numpy as np
import healpy as hp
from scipy.special import sph_harm_y

from tqdm import tqdm

from .core import mich_response_unconvolved
from .orbit import compute_orbits
from .interp import get_response_interpolator, get_sparse_tf_grid
from .util import get_vecs_all_sky
from .const import INTERPOLATION_ALLOWED, LOG_PERFORMANCE, FSTAR
from ..submodel import submodel

__all__ = ["calculate_response_functions"]

logger = logging.getLogger(__name__)

_GridInfo = namedtuple(
    "GridInfo",
    ["nf", "nt", "nf_s", "nt_s", "times", "freqs", "times_sparse", "freqs_sparse"],
)
_SkyInfo = namedtuple("SkyInfo", ["nside", "npix", "active_pixels_idx", "dOmega"])


[docs] def calculate_response_functions(freqs, times, submodels, params, plot_flag=False): """ Compute SGWB responses for each submodel and write it to the submodel attributes. This procedure avoids redundant calculations by only computing the response for a given pixel once. It also uses a sparse time-frequency grid and linearly interpolates the results in-between. It mirrors the functionality in fast_geometry.calculate_response_functions(). Parameters ---------- freqs : array (nfreqs,) frequencies. times : array (ntimes,) times. submodels : list[submodel] The submodels whose response matrices will be computed. They must be in pixel basis and have a skymap attribute, not assumed normalized. These objects also act as output parameters: after the procedure has run, each of them will receive a response matrix as an attribute, which is an array of shape (3, 3, nfreqs, ntimes). The attributes used for the output depend on the value of plot_flag: the response matrix will be written to sm.response_mat and sm.inj_response_mat if it is False, or to sm.fdata_response_mat if it is True. params: dict Parameter dictionary from parse_config(). plot_flag : bool, optional If True, don't write the output response matrix to sm.response_mat or sm.inj_response_mat, but to submodel.fdata_response_mat. This is useful for plotting simulation and analysis models together (hence the name). Defaults to False. Warnings -------- This is experimental, tested only for healpix responses. """ chex.assert_rank([freqs, times], 1) for sm in submodels: assert isinstance(sm, submodel) assert hasattr(sm, "has_map") assert sm.has_map or sm.fullsky del sm # deleted to prevent accidental use out of scope as in issue #28 assert params["lisa_config"] == "orbiting" ############## Set up vectorization, interpolation, orbits ############ # Vectorize twice: on times and on frequencies. # Sky directions are done sequentially _mru_vect = vmap( mich_response_unconvolved, (0, None, None, None), out_axes=2 ) # (3, 3, ntimes) _mru_vec2 = vmap( _mru_vect, (None, 0, None, None), out_axes=2 ) # (3, 3, nfreqs, ntimes) mru_vec2 = jit(_mru_vec2) # Set up sparse grid for interpolation if INTERPOLATION_ALLOWED: logger.info("Using interpolation") times_s, freqs_s = get_sparse_tf_grid(times, freqs) interpolator = get_response_interpolator(times, freqs, times_s, freqs_s) interpolator = jit(interpolator) else: times_s, freqs_s = times, freqs interpolator = lambda x: x # noqa: E731 nt, nf, nt_s, nf_s = len(times), len(freqs), len(times_s), len(freqs_s) nside = params["nside"] npix = hp.nside2npix(nside) dOmega = 4 * np.pi / npix gridinfo = _GridInfo(nf, nt, nf_s, nt_s, times, freqs, times_s, freqs_s) orbits = compute_orbits(times_s) ################ compute responses for pixels with nonzero power ############# active_pixels_idx, active_pixels_vecs = _find_nonzero_pixels(submodels, nside, npix) skyinfo = _SkyInfo(nside, npix, active_pixels_idx, dOmega) if LOG_PERFORMANCE: _log_performance_analysis(mru_vec2, gridinfo, skyinfo, orbits) print("Computing LISA response functions...") resp_s = [] for i, pix_vec in tqdm( zip(active_pixels_idx, active_pixels_vecs), total=len(active_pixels_idx), desc="response", unit="pixel", ): # We block on this computation before dispatching the next one to the GPU, # otherwise the progress bar is meaningless. A possible optimization would be # to dispatch the next iteration before blocking on the current one. _r = mru_vec2(times_s, freqs_s, pix_vec, orbits) resp_s.append(_r.block_until_ready()) chex.assert_shape(resp_s[0], (3, 3, nf_s, nt_s)) chex.assert_equal_shape(resp_s) ################# Integrate on sky for each submodel. ################ # We do this sequentially with a python for loop and using plain numpy (not jax) # arrays. The point of this is to avoid allocating memory for the whole integrand, a # large rank-5 tensor. for sm in submodels: # anisotropic submodels (with maps) if sm.has_map: # submodels with fixed spatial templates (output has 4 dims) if not sm.parameterized_map: rmat, postf_dims = _do_fixed_submodel( sm, gridinfo, skyinfo, interpolator, resp_s ) chex.assert_shape(rmat, (3, 3, nf, nt)) assert postf_dims == 1 # submodels with parameterized spatial templates (output has 5 dims) else: rmat, postf_dims = _do_parameterized_submodel( sm, gridinfo, skyinfo, interpolator, resp_s ) chex.assert_shape(rmat, (3, 3, nf, nt, None)) assert postf_dims == 2 # isotropic SGWB case (effectively a fixed template) elif sm.spatial_model_name == "isgwb": rmat, postf_dims = _do_isotropic_submodel( gridinfo, skyinfo, interpolator, resp_s ) chex.assert_shape(rmat, (3, 3, nf, nt)) assert postf_dims == 1 else: assert False rmat = _mich_to_tdi(rmat, freqs, params, postf_dims) # Assert postconditions for clarity: output shape and dtype. output_shape = (3, 3, nf, nt) if sm.has_map and sm.parameterized_map: if sm.basis == "pixel": sky_size = np.sum(sm.mask_idx) else: sky_size = (sm.almax + 1) ** 2 output_shape = (3, 3, nf, nt, sky_size) assert rmat.shape == output_shape assert rmat.dtype == complex ##### add reference to the response matrix in the submodel ##### if plot_flag: sm.fdata_response_mat = rmat else: sm.response_mat = rmat sm.inj_response_mat = rmat # deprecated
def _find_nonzero_pixels(submodels, nside, npix): is_fullsky = any([getattr(sm, "fullsky", None) for sm in submodels]) if is_fullsky: active_pixels_idx = jnp.arange(npix) else: active_pixels_idx_sm = [ _intarray_to_set(jnp.nonzero(sm.skymap)[0]) for sm in submodels if sm.has_map ] active_pixels_idx = jnp.array( sorted( list(functools.reduce(lambda x, y: x.union(y), active_pixels_idx_sm)) ), dtype=int, ) assert active_pixels_idx.dtype == int active_pixels_vecs = get_vecs_all_sky(nside)[active_pixels_idx] chex.assert_shape(active_pixels_vecs, (None, 3)) return active_pixels_idx, active_pixels_vecs def _do_fixed_submodel(sm, gridinfo, skyinfo, interpolator, resp_s): nf, nt = gridinfo.nf, gridinfo.nt npix, nside = skyinfo.npix, skyinfo.nside dOmega = skyinfo.dOmega active_pixels_idx = skyinfo.active_pixels_idx rmat = np.zeros((3, 3, nf, nt), dtype=complex) postf_dims = 1 # nb of tensor dimensions after freqs if sm.basis == "pixel": chex.assert_shape(sm.skymap, (npix,)) ## normalize skymap such that it integrates to 1 over the sky skymap_normalized = sm.skymap / (jnp.sum(sm.skymap) * dOmega) for i, response_sparse in zip(active_pixels_idx, resp_s): response = interpolator(response_sparse) rmat += skymap_normalized[i] * response * dOmega elif sm.basis == "sph": alm_size = (sm.almax + 1) ** 2 ## angular coordinates of pixel indices theta, phi = hp.pix2ang(nside, active_pixels_idx) Ylms = np.zeros((npix, alm_size), dtype="complex") ## Get the spherical harmonics for ii in range(alm_size): lval, mval = sm.idxtoalm(sm.almax, ii) Ylms[:, ii] = sph_harm_y(mval, lval, theta, phi) ## check that the Ylms have the right number of pixels, sph terms chex.assert_shape(Ylms, (npix, alm_size)) for i, response_sparse in zip(active_pixels_idx, resp_s): response = interpolator(response_sparse) # from Ylms, get skymap intensity for this pixel pix_sph_sum = np.sum(Ylms[i, :] * sm.alms_inj) rmat += pix_sph_sum * response * dOmega else: assert False return rmat, postf_dims def _do_parameterized_submodel(sm, gridinfo, skyinfo, interpolator, resp_s): nf, nt = gridinfo.nf, gridinfo.nt npix, nside = skyinfo.npix, skyinfo.nside dOmega = skyinfo.dOmega active_pixels_idx = skyinfo.active_pixels_idx if sm.basis == "pixel": sky_size = np.sum(sm.mask_idx) rmat = np.zeros((3, 3, nf, nt, sky_size), dtype=complex) postf_dims = 2 # nb of tensor dimensions after freqs for i, response_sparse in zip(active_pixels_idx, resp_s): if i in sm.mask_idx: response = interpolator(response_sparse) rmat[..., i] = response elif sm.basis == "sph": alm_size = (sm.almax + 1) ** 2 ## angular coordinates of pixel indices theta, phi = hp.pix2ang(nside, active_pixels_idx) Ylms = np.zeros((npix, alm_size), dtype="complex") ## Get the spherical harmonics for ii in range(alm_size): lval, mval = sm.idxtoalm(sm.almax, ii) Ylms[:, ii] = sph_harm_y(mval, lval, theta, phi) ## check that the Ylms have the right number of pixels, sph terms chex.assert_shape(Ylms, (npix, alm_size)) ## 3 x 3 x frequency x time x Ylms rmat = np.zeros((3, 3, nf, nt, alm_size), dtype="complex") postf_dims = 2 ## time x Ylms ## loop over pixels, interpolating the response as we go for i, response_sparse in zip(active_pixels_idx, resp_s): response = interpolator(response_sparse) rmat += Ylms[None, None, None, None, i, :] * response[..., None] * dOmega else: assert False return rmat, postf_dims def _do_isotropic_submodel(gridinfo, skyinfo, interpolator, resp_s): nf, nt = gridinfo.nf, gridinfo.nt dOmega = skyinfo.dOmega active_pixels_idx = skyinfo.active_pixels_idx rmat = np.zeros((3, 3, nf, nt), dtype=complex) postf_dims = 1 ## just time ## loop over pixels, interpolating the response as we go for i, response_sparse in zip(active_pixels_idx, resp_s): response = interpolator(response_sparse) rmat += (1 / (4 * jnp.pi)) * response * dOmega return rmat, postf_dims def _mich_to_tdi(rmat, freqs, params, postf_dims): if params["tdi_lev"] == "xyz": mich_to_x1 = 4 * jnp.sin(freqs / FSTAR) ** 2 rmat = ( rmat * mich_to_x1[ np.newaxis, np.newaxis, :, *[np.newaxis for i in range(postf_dims)] ] ) elif params["tdi_lev"] == "aet": # TODO raise NotImplementedError return rmat def _log_performance_analysis(mru_vec2, gridinfo, skyinfo, orbits): times_sparse = gridinfo.times_sparse freqs_sparse = gridinfo.freqs_sparse freqs = gridinfo.freqs nt_s = gridinfo.nt_s nf_s = gridinfo.nf_s ## check runtime, memory use _compiled = mru_vec2.lower( times_sparse, freqs_sparse, jnp.zeros(3), orbits ).compile() logger.debug("freq range %f - %f", freqs[0], freqs[-1]) logger.debug("sparse ntimes = %d, sparse nfreqs = %d", nt_s, nf_s) logger.debug("response execution cost analysis: %s", _compiled.cost_analysis()) logger.debug("response memory cost analysis: %s", _compiled.memory_analysis()) # From a CPU test: # output size = 144 bytes / time / freq / pixel (one complex 3x3 matrix) # bytes accessed = 3049 bytes / time / freq / pixel # flops = 1378 flop / time / freq / pixel # NOTE(solano) the usage of 3x3 complex matrices in our context is a waste of RAM. # We assume equal and constant LTTs, so our SGWB can be perfectly described by two # real series: the AA and EE time-varying PSDs. That would reduce the output from # 144 bytes to 16 bytes per time per frequency per pixel (9x compression). And # that's still using double precision, which we don't need for the response matrix # output. Using single precision we could do 144 -> 8 bytes, an 18x compression. # NOTE(awc) However, the matrix inversions in the likelihood require double # precision, and the spherical harmonic search requires the complex responses. def _intarray_to_set(a: jax.Array) -> set: assert a.dtype == int return set([int(v) for v in a.ravel()])