Source code for blip.src.models

import os
import shutil
import pickle

import numpy as np
from matplotlib import pyplot as plt
from scipy.interpolate import interp1d
import healpy as hp
import logging
import time
from blip.src.submodel import submodel
from blip.src.utils import log_manager, gen_suffixes, catch_color_duplicates
from blip.src.fast_geometry import fast_geometry
from blip.src.faster_geometry import calculate_response_functions

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
jax.config.update("jax_enable_x64", True)


###################################################
###      UNIFIED MODEL PRIOR & LIKELIHOOD       ###
###################################################

[docs] @register_pytree_node_class class Model(): """ Analysis model. This is a container of :class:`submodels <blip.src.submodel.submodel>` with a prior and a likelihood that can be sampled. Parameters ---------- params, inj : dict Configuration dictionaries from :func:`parse_config <blip.config.parse_config>`. fs : array (nfreqs,) Frequency grid in Hz. f0 : array (nfreqs,) Frequency grid scaled by transfer frequency (f0 = fs/2*FSTAR). tsegmid : array (ntimes,) Time grid in seconds. rmat : complex array (nfreqs, ntimes, 3, 3) Data correlation matrix. """ # TODO deduce (fs, f0, tsegmid) from (params, inj) def __init__(self,params,inj,fs,f0,tsegmid,rmat): self.fs = fs self.f0 = f0 self.tsegmid = tsegmid self.params = params self.inj = inj base_component_names = [spec.raw_name for spec in params["model"]] self.submodel_names = base_component_names suffixes = gen_suffixes(base_component_names) self.submodels = {} self.Npar = 0 self.parameters = {} all_parameters = [] spectral_parameters = [] spatial_parameters = [] self.blm_phase_idx = [] for submodel_spec, suffix in zip(params["model"], suffixes): submodel_name = submodel_spec.raw_name sm = submodel(params,inj,submodel_spec,fs,f0,tsegmid,suffix=suffix) self.submodels[submodel_name] = sm if hasattr(sm,"blm_phase_idx"): for ii in sm.blm_phase_idx: self.blm_phase_idx.append(self.Npar+sm.blm_start+ii) # if sm.Npar==0: # sm.fixed_cov = ... ## add handling for 0-parameter, non-noise models here (both spatial and spectral models fixed) self.Npar += sm.Npar self.parameters[submodel_name] = sm.parameters spectral_parameters += sm.spectral_parameters spatial_parameters += sm.spatial_parameters all_parameters += sm.parameters self.parameters['spectral'] = spectral_parameters self.parameters['spatial'] = spatial_parameters self.parameters['all'] = all_parameters ## Having initialized all the components, now compute the LISA response functions # t1 = time.time() # fast_rx = fast_geometry(self.params) # fast_rx.calculate_response_functions(self.f0,self.tsegmid,[self.submodels[smn] for smn in self.submodel_names if smn !='noise'],self.params['tdi_lev']) # t2 = time.time() # print("Time elapsed for calculating the LISA response functions for all submodels via joint computation is {} s.".format(t2-t1)) # ## deallocate to save on memory now that the response functions have been calculated and stored elsewhere # del fast_rx ## update colors as needed catch_color_duplicates(self) ## assign reference to data for use in likelihood self.rmat = rmat return # @jax.jit
[docs] def prior(self,unit_theta): ''' Unified prior function to interatively perform prior draws for each submodel in the proper order Arguments ---------------- unit_theta (array) : draws from the unit cube Returns ---------------- theta (list) : transformed prior draws for all submodels in sequence ''' theta = [] start_idx = 0 for sm_name in self.submodel_names: sm = self.submodels[sm_name] theta += sm.prior(unit_theta[start_idx:(start_idx+sm.Npar)]) start_idx += sm.Npar if len(theta) != len(unit_theta): raise ValueError("Input theta does not have same length as output theta, something has gone wrong!") return theta
# @jax.jit
[docs] def likelihood(self,theta): ''' Unified likelihood function to compare the combined covariance contributions of a generic set of noise/SGWB models to the data. Arguments ---------------- theta (list) : transformed prior draws for all submodels in sequence Returns ---------------- loglike (float) : resulting joint log likelihood ''' start_idx = 0 for i, sm_name in enumerate(self.submodel_names): sm = self.submodels[sm_name] if sm.Npar == 0: theta_i = None else: theta_i = theta[start_idx:(start_idx+sm.Npar)] start_idx += sm.Npar if i==0: cov_mat = sm.cov(theta_i) else: cov_mat = cov_mat + sm.cov(theta_i) ## change axis order to make taking an inverse easier cov_mat = jnp.moveaxis(cov_mat, [-2, -1], [0, 1]) ## take inverse and determinant inv_cov, det_cov = bespoke_inv(cov_mat) logL = -jnp.einsum('ijkl,ijkl', inv_cov, self.rmat) - jnp.einsum('ij->', jnp.log(jnp.pi * self.params['seglen'] * jnp.abs(det_cov))) loglike = jnp.real(logL) return loglike
## this allows for jax/numpyro to properly perform jitting of the class ## all attributes of the model class should be static ## may need to tweak this if/when we implement any kind of RJMCMC approach
[docs] def tree_flatten(self): children = [] # arrays / dynamic values aux_data = {'params':self.params,'inj':self.inj,'fs':self.fs,'f0':self.f0,'tsegmid':self.tsegmid,'rmat':self.rmat} # static values return (children, aux_data)
[docs] @classmethod def tree_unflatten(cls, aux_data, children): return cls(*children, **aux_data)
################################################### ### UNIFIED INJECTION INFRASTRUCTURE ### ###################################################
[docs] class Injection(): """ Simulation model. This is a container of :class:`submodels <blip.src.submodel.submodel>` used for synthesizing LISA data. Parameters ---------- params, inj : dict Configuration dictionaries from :func:`parse_config <blip.config.parse_config>`. fs : array (nfreqs,) Frequency grid in Hz. f0 : array (nfreqs,) Frequency grid scaled by transfer frequency (f0 = fs/2*FSTAR). tsegmid : array (ntimes,) Time grid in seconds. """ # TODO deduce (fs, f0, tsegmid) from (params, inj) def __init__(self,params,inj,fs,f0,tsegmid): self.params = params self.inj = inj self.frange = fs self.f0 = f0 self.tsegmid = tsegmid ## separate into components self.component_specs = inj['injection'] self.component_names = [spec.raw_name for spec in self.component_specs] N_inj = len(self.component_specs) ### commenting this out because we're switching to active specification of duplicates in the params file ## check for and differentiate duplicate injections ## this will append 1 (then 2, then 3, etc.) to any duplicate component names ## we will also generate appropriate variable suffixes to use in plots, etc.. # self.component_names = catch_duplicates(base_component_names) ## it's useful to have a version of this without the detector noise self.sgwb_component_names = [name for name in self.component_names if name!='noise'] suffixes = gen_suffixes(self.component_names) ## initialize components self.components = {} self.truevals = {} ## step through and build components ## parallelization has been depreciated now that the response function calculations are handled elsewhere for i, (component_spec, suffix) in enumerate(zip(self.component_specs,suffixes)): print("Building injection for {} (component {} of {})...".format(component_spec.raw_name,i+1,N_inj)) cm = submodel(params,inj,component_spec,fs,f0,tsegmid,injection=True,suffix=suffix) self.components[component_spec.raw_name] = cm self.truevals[component_spec.raw_name] = cm.truevals if cm.has_map: self.plot_skymaps(component_spec.raw_name) ## Having initialized all the components, now compute the LISA response functions if self.inj['parallel_inj'] and self.inj['response_nthread']>1: rx_nthreads = self.inj['response_nthread'] else: rx_nthreads = 1 t1 = time.time() submodels_sgwb = [self.components[cmn] for cmn in self.sgwb_component_names] if params["faster_geometry"]: calculate_response_functions(fs, self.tsegmid, submodels_sgwb, params) else: fast_rx = fast_geometry(self.params,nthreads=rx_nthreads) fast_rx.calculate_response_functions(self.f0,self.tsegmid,submodels_sgwb,self.params['tdi_lev']) t2 = time.time() print("Time elapsed for calculating the LISA response functions for all components via joint computation is {} s.".format(t2-t1)) ## initialize default plotting lower ylim self.plot_ylim = None ## update colors as needed catch_color_duplicates(self)
[docs] def add_component(self,name_args): ''' Wrapper function for the injection component creation process, to allow for parallelization. Arguments ------------------ name_args (tuple) : (component_name,suffix) for one component Returns ------------------ cm (submodel object) : Injection component ''' component_name, suffix = name_args cm = submodel(self.params,self.inj,component_name,self.frange,self.f0,self.tsegmid,injection=True,suffix=suffix) print("Injection component build complete for component: {}".format(component_name)) return cm
# def compute_convolved_spectra(self,component_name,fs_new=None,channels='11',return_fs=False,imaginary=False): # ''' # Wrapper to return the frozen injected detector-convolved GW spectra for the desired channels. # # Useful note - these frozen spectra are computed in diag_spectra(), as they are calculated and saved at the analysis frequencies. # # Also note that this is meant for plotting purposes only, and includes interpolation/absolute values that are not desirable in a data generation/analysis environment. # # Arguments # ----------- # component_name (str) : the name (key) of the Injection component to use. # fs_new (array) : If desired, frequencies at which to interpolate the convolved PSD # channels (str) : Which channel cross/auto-correlation PSD to plot. Default is '11' auto-correlation, i.e. XX for XYZ, 11 for Michelson, AA for AET. # return_fs (bool) : If True, also returns the frequencies at which the PSD has been evaluated. Default False. # imaginary (bool) : If True, returns the magnitude of the imaginary component. Default False. # # Returns # ----------- # PSD (array) : Power spectral density of the specified channels' auto/cross-correlation at the desired frequencies. # fs (array, optional) : The PSD frequencies, if return_fs==True. # # ''' # # cm = self.components[component_name] # ## split the channel indicators # c1_idx, c2_idx = int(channels[0]) - 1, int(channels[1]) - 1 # # if not imaginary: # PSD = np.abs(np.real(cm.frozen_convolved_spectra[c1_idx,c2_idx,:])) # else: # PSD = 1j * np.abs(np.imag(cm.frozen_convolved_spectra[c1_idx,c2_idx,:])) # # ## populations need some finessing due to frequency subtleties # if hasattr(cm,"ispop") and cm.ispop: # fs = cm.population.frange_true # if (fs_new is not None) and not np.array_equal(fs_new,cm.population.frange_true): # with log_manager(logging.ERROR): # PSD_interp = interp1d(fs,PSD,bounds_error=False,fill_value=0) # PSD = PSD_interp(fs_new) # fs = fs_new # else: # fs = self.frange # ## there is no way to compute the convolved injected spectra once the injected response functions have been flushed # ## we have saved them, however, and can either just use the saved frozen spectra or interpolate to a new frequency grid # ## WARNING: interpolation will likely result in low fidelity at f < 3e-4 Hz. # if fs_new is not None: # with log_manager(logging.ERROR): # PSD_interp = interp1d(fs,np.log10(PSD)) # PSD = 10**PSD_interp(fs_new) # fs = fs_new # # if return_fs: # return fs, PSD # else: # return PSD
[docs] def compute_convolved_spectra(self,component_name,fs_new=None,channels='11',return_fs=False,imaginary=False): ''' Wrapper to return the frozen injected detector-convolved GW spectra for the desired channels. Useful note - these frozen spectra are computed in diag_spectra(), as they are calculated and saved at the analysis frequencies. Also note that this is meant for plotting purposes only, and includes interpolation/absolute values that are not desirable in a data generation/analysis environment. Arguments ----------- component_name (str) : the name (key) of the Injection component to use. fs_new (array) : If desired, frequencies at which to interpolate the convolved PSD channels (str) : Which channel cross/auto-correlation PSD to plot. Default is '11' auto-correlation, i.e. XX for XYZ, 11 for Michelson, AA for AET. return_fs (bool) : If True, also returns the frequencies at which the PSD has been evaluated. Default False. imaginary (bool) : If True, returns the magnitude of the imaginary component. Default False. Returns ----------- PSD (array) : Power spectral density of the specified channels' auto/cross-correlation at the desired frequencies. fs (array, optional) : The PSD frequencies, if return_fs==True. ''' cm = self.components[component_name] ## split the channel indicators c1_idx, c2_idx = int(channels[0]) - 1, int(channels[1]) - 1 ## simulated data frequencies if fs_new == 'data': fs = cm.fdata PSD_complex = cm.fdata_convolved_spectra[c1_idx,c2_idx,:] ## all other cases start from the original injected frequencies else: fs = self.frange PSD_complex = cm.frozen_convolved_spectra[c1_idx,c2_idx,:] ## handle complex spectra as desired if not imaginary: PSD = np.abs(np.real(PSD_complex)) else: PSD = 1j * np.abs(np.imag(PSD_complex)) ## estimate spectra at new frequencies -- WARNING: requires interpolation, usually produces low-fidelity results ## only really useful for quick checks and visualization, NOT for analysis purposes! if fs_new is not None and fs_new != 'data': ## populations need some finessing due to frequency subtleties if hasattr(cm,"ispop") and cm.ispop: fs = cm.population.frange_true if (fs_new is not None) and not np.array_equal(fs_new,cm.population.frange_true): with log_manager(logging.ERROR): PSD_interp = interp1d(fs,PSD,bounds_error=False,fill_value=0) PSD = PSD_interp(fs_new) fs = fs_new else: ## there is no way to compute the convolved injected spectra once the injected response functions have been flushed ## we have saved them, however, and can either just use the saved frozen spectra or interpolate to a new frequency grid ## WARNING: interpolation will likely result in low fidelity at f < 3e-4 Hz. with log_manager(logging.ERROR): PSD_interp = interp1d(fs,np.log10(PSD)) PSD = 10**PSD_interp(fs_new) fs = fs_new if return_fs: return fs, PSD else: return PSD
[docs] def plot_injected_spectra(self,component_name,fs_new=None,ax=None,convolved=False,legend=False,channels='11',return_PSD=False,scale='log',flim=None,ymins=None,**plt_kwargs): ''' Wrapper to plot the injected spectrum component on the specified matplotlib axes (or current axes if unspecified). Arguments ----------- component_name (str) : the name (key) of the Injection component to use. fs_new (array) : If desired, frequencies at which to interpolate the convolved PSD ax (matplotlib axes) : Axis on which to plot. Default None (will plot on current axes.) convolved (bool) : If True, convolve the injected spectra with the detector response. Default False. legend (bool) : If True, generate a legend entry. Default False. channels (str) : Which channel cross/auto-correlation PSD to plot. Default is '11' auto-correlation, i.e. XX for XYZ, 11 for Michelson, AA for AET. return_PSD (bool) : If True, also returns the plotted PSD. Default False. scale (str) : Matplotlib scale at which to plot ('log' or 'linear'). Default 'log'. flim (tuple) : (fmin,fmax) plot limits. Default None (will use fmin,fmax as specified in the params file.) ymins (list) : External list to which, if specified, will be added the lower ylim of the injected spectra. **plt_kwargs (kwargs) : matplotlib.pyplot keyword arguments Returns ----------- PSD plot on specified axes. PSD (array, optional) : Power spectral density of the specified channels' auto/cross-correlation at the desired frequencies. ''' ## grab component cm = self.components[component_name] ## set axes if ax is None: ax = plt.gca() ## set fmin/max to specified values, or default to the ones in params if flim is not None: fmin = flim[0] fmax = flim[1] else: fmin = self.params['fmin'] fmax = self.params['fmax'] ## special treatment of population frequencies # if hasattr(self.components[component_name],"ispop") and self.components[component_name].ispop: # fs_base = self.components[component_name].population.frange_true # else: # fs_base = self.frange ## get frozen injected spectra at original injection frequencies and convolve with detector response if desired if convolved: if component_name == 'noise': raise ValueError("Cannot convolve noise spectra with the detector GW response - this is not physical. (Set convolved=False in the function call!)") fs, PSD = self.compute_convolved_spectra(component_name,channels=channels,return_fs=True,fs_new=fs_new) else: ## handle wanting to plot at new frequencies (typically the data frequencies) ## original injection frequencies if fs_new is None: if hasattr(cm,"ispop") and cm.ispop: PSD = cm.population.Sgw_true fs = cm.population.frange_true else: fs = self.frange PSD = cm.frozen_spectra ## data frequencies (self.fdata in the code) elif (type(fs_new) is str) and (fs_new == 'data'): fs = cm.fdata PSD = cm.fdata_spectra ## estimate spectra at new frequencies -- WARNING: requires interpolation, usually produces low-fidelity results ## only really useful for quick checks and visualization, NOT for analysis purposes! else: if component_name == 'noise': fstar = 3e8/(2*np.pi*cm.armlength) f0_new = fs_new/(2*fstar) PSD = cm.instr_noise_spectrum(fs_new,f0_new,Np=10**cm.injvals['log_Np'],Na=10**cm.injvals['log_Na']) ## special treatment for the population case elif hasattr(cm,"ispop") and cm.ispop: PSD = cm.population.Sgw_true fs = cm.population.frange_true if not np.array_equal(fs_new,cm.population.frange_true): ## the interpolator gets grumpy sometimes, but it's not an actual issue hence the logging wrapper with log_manager(logging.ERROR): PSD_interp = interp1d(fs,PSD,bounds_error=False,fill_value=0) PSD = PSD_interp(fs_new) fs = fs_new else: Sgw_args = [cm.truevals[parameter] for parameter in cm.spectral_parameters] PSD = cm.compute_Sgw(fs_new,Sgw_args) fs = fs_new ## noise will return the 3x3 covariance matrix, need to grab the desired channel cross-/auto-power ## generically capture anything that looks like a covariance matrix for future-proofing if (len(PSD.shape)==3) and (PSD.shape[0]==PSD.shape[1]==3): I, J = int(channels[0]) - 1, int(channels[1]) - 1 PSD = PSD[I,J,:] filt = (fs>=fmin)*(fs<=fmax) if legend: label = cm.fancyname if plt_kwargs is None: plt_kwargs = {} plt_kwargs['label'] = label else: if 'label' not in plt_kwargs.keys(): plt_kwargs['label'] = label if scale=='log': ax.loglog(fs[filt],PSD[filt],**plt_kwargs) elif scale=='linear': ax.plot(fs[filt],PSD[filt],**plt_kwargs) else: raise ValueError("We only support linear and log plots, there is no secret third option!") if ymins is not None: ymins.append(PSD.min()) if return_PSD: return PSD else: return
[docs] def plot_skymaps(self,component_name,save_figures=True,return_mapdata=False,**plt_kwargs): ''' Function to plot the injected skymaps. NOTE - will need to be generalized when I add the astro injections ''' cm = self.components[component_name] # deals with projection parameter if self.params['projection'] is None: coord = 'E' elif self.params['projection']=='G' or self.params['projection']=='C': coord = ['E',self.params['projection']] elif self.params['projection']=='E': coord = self.params['projection'] else: raise TypeError('Invalid specification of projection, projection can be E, G, or C') if return_mapdata: cm_data = {} ## dimensionless energy density at 1 mHz spec_args = [cm.truevals[parameter] for parameter in cm.spectral_parameters] Omega_1mHz = cm.omegaf(1e-3,*spec_args) if hasattr(cm,"skymap"): Omegamap_pix = Omega_1mHz * cm.skymap/(np.sum(cm.skymap)*hp.nside2pixarea(self.params['nside'])/(4*np.pi)) ## tell healpy to shush with log_manager(logging.ERROR): hp.mollview(Omegamap_pix, coord=coord, title=r'Injected pixel map $\Omega (f = 1 mHz)$', unit=r"$\Omega(f= 1mHz)$", cmap=self.params['colormap']) hp.graticule() if save_figures: np.savetxt(self.params['out_dir']+'/inj_pixelmap_data.txt',Omegamap_pix) plt.savefig(self.params['out_dir'] + '/inj_pixelmap'+component_name+'.png', dpi=150) print('Saving injection pixel map at ' + self.params['out_dir'] + '/inj_pixelmap'+component_name+'.png') plt.close() if return_mapdata: cm_data['Omega_pixelmap'] = Omegamap_pix cm_data['normed_pixelmap'] = cm.skymap/(np.sum(cm.skymap)*hp.nside2pixarea(self.params['nside'])/(4*np.pi)) if hasattr(cm,"sph_skymap"): ## sph map Omegamap_inj = Omega_1mHz * cm.sph_skymap ## tell healpy to shush with log_manager(logging.ERROR): hp.mollview(Omegamap_inj, coord=coord, title=r'Injected angular distribution map $\Omega (f = 1 mHz)$', unit=r"$\Omega(f= 1mHz)$", cmap=self.params['colormap']) hp.graticule() if save_figures: plt.savefig(self.params['out_dir'] + '/inj_skymap'+component_name+'.png', dpi=150) print('Saving injected sph skymap at ' + self.params['out_dir'] + '/inj_skymap'+component_name+'.png') plt.close() if return_mapdata: cm_data['Omega_sphmap'] = Omegamap_inj cm_data['normed_sphmap'] = cm.sph_skymap if return_mapdata: return cm_data else: return
[docs] def extract_and_save_skymap_data(self,map_data_path=None): ## load or create plot_data dict if map_data_path is None: map_data_path = self.params['out_dir']+'/plot_data.pickle' if os.path.exists(map_data_path): with open(map_data_path, 'rb') as datafile: plot_data = pickle.load(datafile) if 'map_data' not in plot_data.keys(): plot_data['map_data'] = {} else: plot_data = {'map_data':{}} plot_data['map_data']['inj_maps'] = {} for cmn in self.component_names: if self.components[cmn].has_map: plot_data['map_data']['inj_maps'][cmn] = self.plot_skymaps(cmn,save_figures=False,return_mapdata=True) ## save map data if os.path.exists(map_data_path): ## move to temp file temp_file = map_data_path + ".temp" with open(temp_file, "wb") as datafile: pickle.dump(plot_data,datafile) shutil.move(temp_file, map_data_path) else: with open(map_data_path, 'wb') as datafile: plot_data = pickle.dump(plot_data,datafile) print("Data for injected skymaps saved to {}".format(map_data_path))
[docs] @jax.jit def bespoke_inv(A): """ compute inverse without division by det; ...xv3xc3 input, or array of matrices assumed Credit to Eelco Hoogendoorn at stackexchange for this piece of wizardy. This is > 3 times faster than numpy's det and inv methods used in a fully vectorized way as of numpy 1.19.1 https://stackoverflow.com/questions/21828202/fast-inverse-and-transpose-matrix-in-python """ AI = jnp.empty_like(A) for i in range(3): # AI[...,i,:] = jnp.cross(A[...,i-2,:], A[...,i-1,:]) AI = AI.at[...,i,:].set(jnp.cross(A[...,i-2,:], A[...,i-1,:])) ## jax version det = jnp.einsum('...i,...i->...', AI, A).mean(axis=-1) inv_T = AI / det[...,None,None] # inverse by swapping the inverse transpose return jnp.swapaxes(inv_T, -1,-2), det