# Source code for blip.src.faster_geometry.core

"""
Core logic for LISA response calculations.

Here is the logic, top-down.

- :func:`mich_response_unconvolved` computes the response correlation matrix. This is
  the ultimate goal.
- It calls :func:`mich_antenna_pattern` which computes the LISA response for individual
  TDI spacecraft or TDI channels, and individual GW polarization modes.
- :func:`mich_antenna_pattern` fully contracts (double-dot product) a detector tensor
  with a polarization tensor.
- To compute the detector tensor you need orbital information as well as the
  :func:`timing transfer function <timing_transfer_fn>`.
- To define a polarization tensor you need an orthonormal basis of the tangent plane to
  the celestial sphere at the sky direction n (:func:`get_ortho_basis_ecliptic_3d`).
"""

########### Implementation details ##########
# This module leans heavily on JAX automatic vectorization and JIT compilation. All the
# functions are written for the simplest possible array shapes (mostly scalars), making
# them easier to check for correctness. Trace-time assertions, mostly on array shapes,
# are placed throughout the code and double as machine-checked comments.

# The formulas implemented here are exactly the ones in the original BLIP paper
# Banagiri+21, with two exceptions:
# - the mistaken factor of 1/4pi is removed from eq. (18) which defines the response
#   matrix;
# - all throughout this module, the sign of n is flipped wrt. the paper, i.e. the vector
#   n here means the direction towards the GW source, not from the source.

# TODO check if the sign of n conflicts with the rest of BLIP.


from jax import numpy as jnp
import chex

from .orbit import get_arm_orientations, get_sc_positions
from .const import FSTAR, CLIGHT

# Public API of this module, listed top-down along the computation chain described in
# the module docstring: response matrix -> antenna pattern -> detector tensor, followed
# by the sky-basis and timing-transfer helpers. ``_sinc`` is intentionally private.
__all__ = [
    "mich_response_unconvolved",
    "mich_antenna_pattern",
    "mich_detector_tensor",
    "get_ortho_basis_ecliptic_3d",
    "timing_transfer_fn",
]


def mich_response_unconvolved(t, f, n, orbits):
    r"""
    Unconvolved Michelson (TDI gen 0) sky SGWB response.

    Parameters
    ----------
    t : float
        time
    f : float
        frequency
    n : array (3,)
        normalized vector in the direction of the GW source.
    orbits : tuple
        orbital information returned by compute_orbits().

    Returns
    -------
    complex array (3, 3)
        Unconvolved response matrix for the three data channels. It is Hermitian
        by construction.

    Notes
    -----
    The quantity computed here is exactly

    .. math::
        \frac{1}{2} \sum_{A=+,\times}\left(F_I^A(f, \mathbf{n})^* F_J^A(f, \mathbf{n}) \right)

    where :math:`I` and :math:`J` stand for TDI channels, and :math:`F_I^A` are
    antenna pattern functions. This is integrated against the sky map to produce
    the GW time-frequency correlation matrix.

    The convention here agrees with Criswell+25 eq. (6) (but for a complex
    conjugate), which is a corrected version of Banagiri+21 eq. (18).
    """
    chex.assert_shape([t, f], ())
    chex.assert_shape(n, (3,))

    # Each antenna pattern depends on a single channel, so evaluate the three "plus"
    # and three "cross" patterns once rather than once per (I, J) pair.
    # The comprehensions are ordinary python loops, unrolled at trace time, so the
    # channel index stays trace-time known as mich_antenna_pattern requires.
    fp = jnp.stack([mich_antenna_pattern(t, f, n, "plus", c, orbits) for c in range(3)])
    fc = jnp.stack([mich_antenna_pattern(t, f, n, "cross", c, orbits) for c in range(3)])
    chex.assert_shape([fp, fc], (3,))

    # res[I, J] = 0.5 * (conj(F_I^+) F_J^+ + conj(F_I^x) F_J^x); the outer product of
    # conj(F) with F is Hermitian, so no explicit mirroring of the upper triangle is
    # needed.
    res = 0.5 * (jnp.outer(fp.conj(), fp) + jnp.outer(fc.conj(), fc))
    chex.assert_shape(res, (3, 3))
    return res
def mich_antenna_pattern(t, f, n, polarization: str, channel, orbits):
    """
    Compute Michelson (TDI gen 0) antenna pattern.

    Checked against Banagiri+21 and Romano & Cornish 2017.

    Parameters
    ----------
    t : float
        time
    f : float
        frequency
    n : array (3,)
        Vector in the direction of the GW source. It is renormalized to unit
        length here, so any nonzero vector is accepted.
    polarization : str
        Should be "plus" or "cross". Must be known at JAX trace time.
    channel : int
        Channel index 0, 1, 2. Must be known at JAX trace time.
    orbits : tuple
        orbital information returned by compute_orbits().

    Returns
    -------
    complex
        The antenna pattern: the full contraction of the channel's detector
        tensor with the requested polarization tensor.
    """
    # polarization and channel must be trace-time known
    chex.assert_shape([t, f, channel], ())
    chex.assert_shape(n, (3,))
    assert polarization in ["plus", "cross"], f"invalid polarization {polarization!r}"
    assert 0 <= channel < 3, f"invalid channel {channel!r}"

    # lam, beta = ecliptic coordinates (lon, lat) of n.
    # beta is computed with atan2 instead of arcsin(n[2]): arcsin is ill-conditioned
    # for |n[2]| -> 1, so near the ecliptic poles rounding errors in n[2] would tilt
    # the (l, m) basis out of the plane orthogonal to n. atan2 is well-conditioned
    # everywhere, and at the poles (hypot == 0) it still gives beta = +-pi/2.
    n = n / jnp.linalg.norm(n)
    beta = jnp.atan2(n[2], jnp.hypot(n[0], n[1]))
    lam = jnp.atan2(n[1], n[0])
    lam = jnp.where(lam < 0, lam + 2 * jnp.pi, lam)

    # polarization tensor, Romano & Cornish 2017 eq 2.3
    _, ell, emm = get_ortho_basis_ecliptic_3d(lam, beta)
    if polarization == "plus":
        pol_tens = jnp.outer(ell, ell) - jnp.outer(emm, emm)
    else:
        pol_tens = jnp.outer(ell, emm) + jnp.outer(emm, ell)

    # detector tensor of the Michelson channel whose vertex is S/C number channel + 1
    sc = channel + 1
    u, v = get_arm_orientations(t, sc, orbits)
    r = get_sc_positions(t, orbits)[sc - 1]
    det_tens = mich_detector_tensor(f, u, v, n, r)

    chex.assert_shape([det_tens, pol_tens], (3, 3))
    # tensordot with its default axes=2 contracts both indices: sum_ij D_ij e_ij
    res = jnp.tensordot(det_tens, pol_tens)
    chex.assert_shape(res, ())
    return res
def mich_detector_tensor(f, u, v, n, r):
    """
    Michelson channel detector tensor.

    Checked against Banagiri+21 eq (15). The tensor returned is::

        D = 1/2 * exp(2 pi i f (n . r) / c) * [T(f, u . n) u u^T - T(f, v . n) v v^T]

    where T is :func:`timing_transfer_fn` and c is ``CLIGHT``.

    Parameters
    ----------
    f : float
        frequency in Hz
    u : array (3,)
        normalized vector in the direction of the first arm
    v : array (3,)
        normalized vector in the direction of the second arm
    n : array (3,)
        normalized vector in the direction of the GW source
    r : array (3,)
        position of vertex S/C in barycentric ecliptic cartesian coordinates

    Returns
    -------
    complex array (3, 3)
        detector tensor
    """
    chex.assert_shape(f, ())
    chex.assert_shape([u, v, n, r], (3,))

    # Arm projectors u u^T and v v^T.
    arm_u = jnp.outer(u, u)
    arm_v = jnp.outer(v, v)
    chex.assert_shape([arm_u, arm_v], (3, 3))

    # Projections of the sky direction onto each arm and onto the vertex position.
    cos_u = jnp.dot(u, n)
    cos_v = jnp.dot(v, n)
    n_dot_r = jnp.dot(n, r)
    ang_freq = 2 * jnp.pi * f
    chex.assert_shape([cos_u, cos_v, n_dot_r, ang_freq], ())

    transfer_u = timing_transfer_fn(f, cos_u)
    transfer_v = timing_transfer_fn(f, cos_v)
    chex.assert_shape([transfer_u, transfer_v], ())

    # Phase of the wave at the vertex S/C relative to the barycentre (delay n.r / c).
    phase = jnp.exp(1j * ang_freq * n_dot_r / CLIGHT)
    tensor = 0.5 * phase * (transfer_u * arm_u - transfer_v * arm_v)
    chex.assert_shape(tensor, (3, 3))
    return tensor
def timing_transfer_fn(f, costheta):
    """
    Timing transfer function for two-way photon propagation.

    Checked against Banagiri+21 eq (16) and Cornish & Rubbo 2003 eq (37).
    Also agrees with Romano & Cornish 2017 eq (5.27) up to a constant 2L/c.
    This seems due to the conversion between strain and timing measurements,
    eq (5.4) in the living review.

    Parameters
    ----------
    f : float
        frequency in Hz
    costheta : float
        cosine of angle between arm and sky direction

    Returns
    -------
    complex
        the transfer function.

    Notes
    -----
    With x = f / (2 * FSTAR) and mu = costheta, the value returned is::

        1/2 * [sinc(x (1 + mu)) exp(-i x (3 - mu))
               + sinc(x (1 - mu)) exp(-i x (1 - mu))]

    using the unnormalized sinc(y) = sin(y) / y.
    """
    chex.assert_shape([f, costheta], ())

    x = f / (2 * FSTAR)
    one_plus = 1 + costheta
    one_minus = 1 - costheta

    # Sum of two sinc-envelope-times-phase terms.
    first = _sinc(x * one_plus) * jnp.exp(-1j * x * (3 - costheta))
    second = _sinc(x * one_minus) * jnp.exp(-1j * x * one_minus)
    transfer = 0.5 * (first + second)

    chex.assert_shape(transfer, ())
    return transfer
def get_ortho_basis_ecliptic_3d(lam, beta):
    """
    Get right-handed orthonormal basis (n, l, m).

    This is the basis in Romano & Cornish 2017 eq (2.4): the spherical-coordinate
    unit vectors at colatitude pi/2 - beta and azimuth lam. n is radial,
    l = dn/dtheta and m = (dn/dphi) / sin(theta), so that n x l = m.

    Parameters
    ----------
    lam : float
        ecliptic longitude
    beta : float
        ecliptic latitude

    Returns
    -------
    tuple (n, l, m)
        A tuple of arrays of shape (3,).
    """
    chex.assert_shape([lam, beta], ())

    colatitude = jnp.pi / 2 - beta
    cos_t, sin_t = jnp.cos(colatitude), jnp.sin(colatitude)
    cos_p, sin_p = jnp.cos(lam), jnp.sin(lam)

    radial = jnp.array([sin_t * cos_p, sin_t * sin_p, cos_t])
    polar = jnp.array([cos_t * cos_p, cos_t * sin_p, -sin_t])
    azimuthal = jnp.array([-sin_p, cos_p, 0])

    chex.assert_shape([radial, polar, azimuthal], (3,))
    return radial, polar, azimuthal
# Surprisingly, this does not exist in jax.scipy.special def _sinc(x): # Inner select avoids NaN when differentiating at x=0 _x = jnp.select([x != 0, True], [x, 1.0]) return jnp.select([x != 0, True], [jnp.sin(_x) / _x, 1.0])