#!/bin/env python3
import pickle
import numpy as np
import sys, os, shutil, subprocess
# FIXME is this actually needed?
sys.path.append(os.getcwd()) ## this lets python find src
# FIXME why is there a src submodule?
from blip.src.makeLISAdata import LISAdata
from blip.src.models import Model, Injection
from blip.src.submodel import SubmodelKind
from blip.src.fast_geometry import fast_geometry, get_model_responses
from blip.src.faster_geometry import calculate_response_functions
from blip.src.utils import ensure_color_matching
from blip.tools.plotmaker import cornermaker
from blip.tools.plotmaker import mapmaker
from blip.tools.plotmaker import fitmaker
from blip.config import parse_config
import matplotlib.pyplot as plt
import matplotlib
from multiprocessing import Pool
import time
#import jax
from jax import config
# JAX defaults to 32-bit floats; the samplers need full double precision.
config.update("jax_enable_x64", True)

## set default fonts: STIX for both regular text and mathtext in every figure
plot_params = {
    'font.family': 'STIXGeneral',
    'mathtext.fontset': 'stix',
}
matplotlib.rcParams.update(plot_params)
class LISA(LISAdata, Model):
    '''
    Generic class for getting data and setting up the prior space and likelihood.

    Construction loads external data, reloads BLIP-generated data, or simulates
    new data (selected by ``params['load_data']`` / ``inj['loadInj']``), builds the
    data correlation matrix and, unless ``inj['inj_only']`` is set, the Bayesian
    ``Model`` and its response functions. Diagnostic spectrum plots are written
    to ``params['out_dir']``.
    '''

    def __init__(self, params, inj):
        '''
        Parameters
        ----------
        params : dict-like
            Run parameters (as produced by ``parse_config``).
        inj : dict-like
            Injection parameters (as produced by ``parse_config``).
        '''
        # set up the LISAdata class
        LISAdata.__init__(self, params, inj)

        # Generate or get mldc or other loaded data
        if self.params['load_data']:
            self.process_external_data()
        elif self.inj['loadInj']:
            self.load_blip_injection_data()
        else:
            self.makedata()

        ## compute the cross/auto-channel correlations
        self.make_data_correlation_matrix()

        if not self.inj['inj_only']:
            # Set up the Bayes class
            print("Building Bayesian model...")
            self.Model = Model(params, inj, self.fdata, self.f0, self.tsegmid, self.rmat)

            ## compute response functions (noise submodels have none)
            if params["faster_geometry"]:
                submodels_sgwb = [sm for sm in self.Model.submodels.values() if sm.kind != SubmodelKind.NOISE]
                calculate_response_functions(self.fdata, self.tsegmid, submodels_sgwb, params)
            else:
                get_model_responses(self.Model)

            # make sure matching injections/models have matching colors
            if not self.params['load_data']:
                ensure_color_matching(self.Model, self.injection)

        ## if doing an injection, we also need the responses evaluated at the data frequencies for plotting/validation
        if not self.params['load_data'] and not self.inj['loadInj']:
            self.generate_inj_fdata_responses()
            self.compute_inj_fdata_spectra()

        # Make some simple diagnostic plots to contrast spectra
        if self.params['load_data']:
            self.plot_spectra()
        elif not self.inj['loadInj']:
            self.diag_spectra()

    def load_blip_injection_data(self):
        '''
        Function to load in data created via BLIP's own Injection routines, with associated Injection file.

        This is useful because we save true values, frozen spectra, etc. in addition to the raw data stream,
        so loading autogenerated data in this way allows BLIP's plotting and post-processing routines to
        return significantly more informative outputs.

        Reads ``injection.pickle`` and ``simulated_data.npz`` from ``inj['injdir']`` and sets
        ``injection``, ``timearray``, ``h1``-``h3``, ``r1``-``r3``, ``fdata``, ``tsegstart``,
        ``tsegmid`` and ``f0``.
        '''
        print("Loading extant BLIP-generated data from {}".format(self.inj['injdir']))
        # NOTE: pickle.load can execute arbitrary code -- only point injdir at trusted BLIP output.
        with open(self.inj['injdir'] + '/injection.pickle', 'rb') as injectionfile:
            self.injection = pickle.load(injectionfile)
        # use the NpzFile as a context manager so its underlying file handle is closed
        with np.load(self.inj['injdir'] + '/simulated_data.npz') as loaded_data:
            self.timearray = loaded_data['timearray']
            self.h1, self.h2, self.h3 = loaded_data['h1'], loaded_data['h2'], loaded_data['h3']

        # Generate lisa freq domain data from time domain data
        self.r1, self.r2, self.r3, self.fdata, self.tsegstart, self.tsegmid = self.tser2fser(self.h1, self.h2, self.h3, self.timearray)

        # Characteristic frequency f* = c/(2*pi*L); define f0 = f/(2 f*)
        cspeed = 3e8
        fstar = cspeed/(2*np.pi*self.armlength)
        self.f0 = self.fdata/(2*fstar)

    def makedata(self):
        '''
        Simulate the LISA data stream using the LISAdata/Injection machinery.

        Builds ``self.injection``, generates TDI noise, adds the time-domain
        contribution of every SGWB injection component (each with its own RNG key
        derived from ``params['seed']`` when given), and converts the result to the
        frequency domain. Nothing is returned; sets ``timearray``, ``h1``-``h3``,
        ``r1``-``r3``, ``fdata``, ``tsegstart``, ``tsegmid`` and ``f0``.
        '''
        # Simulation time-frequency grid precomputed during configuration parsing
        nsplice = self.params["nsplice"]
        tsegmid = self.params["tsegmid"]
        Npersplice = self.params["Npersplice"]

        ## leave out f = 0
        frange = np.fft.rfftfreq(Npersplice, 1.0/self.params['fs'])[1:]

        ## the characteristic frequency of LISA, and the scaled frequency array
        fstar = 3e8/(2*np.pi*self.armlength)
        f0 = frange/(2*fstar)

        ## Build the Injection object
        print("Constructing injection...")
        self.injection = Injection(self.params, self.inj, frange, f0, tsegmid)

        ## assign a couple additional universal injection attributes needed in add_sgwb_data()
        self.injection.Npersplice = Npersplice
        self.injection.nsplice = nsplice

        # Generate TDI noise
        times, self.h1, self.h2, self.h3 = self.injection.components['noise'].gen_noise_spectrum()
        delt = times[1] - times[0]

        # Cut to required size
        N = int((self.params['dur'])/delt)
        self.timearray = np.arange(0, self.params['dur'], delt)[0:N]
        self.h1, self.h2, self.h3 = self.h1[0:N], self.h2[0:N], self.h3[0:N]

        ## create time-domain contribution from each injection component that isn't noise
        print("Simulating SGWB contributions...")

        ## set branching keys to ensure multi-sgwb independence
        if 'seed' in self.params.keys():
            rng_seed = np.int32(self.params['seed'])
        else:
            rng_seed = np.random.randint(int(1e9), dtype='int32')
        rng = np.random.default_rng(rng_seed)
        keys = rng.integers(int(1e9), size=len(self.injection.sgwb_component_names), dtype='int32')

        for component, key in zip(self.injection.sgwb_component_names, keys):
            h1_gw, h2_gw, h3_gw, times = self.add_sgwb_data(self.injection.components[component], key)
            assert h1_gw.size == times.size
            assert h1_gw.size == self.h1.size
            assert np.isclose(delt, times[1]-times[0]), 'The noise and signal arrays are at different sampling frequencies!'

            # cut again to required length, because the signal (constructed from splices)
            # might be longer or shorter than the noise.
            N = min(N, len(times))
            self.timearray = self.timearray[0:N]
            times = times[0:N]
            self.h1, self.h2, self.h3 = self.h1[0:N], self.h2[0:N], self.h3[0:N]
            h1_gw, h2_gw, h3_gw = h1_gw[0:N], h2_gw[0:N], h3_gw[0:N]

            # Add gravitational-wave time series to noise time-series
            self.h1 = self.h1 + h1_gw
            self.h2 = self.h2 + h2_gw
            self.h3 = self.h3 + h3_gw

        # Keep params['fs'] consistent with the actual sample spacing of the simulated
        # series (the sample rate may have been increased for time-shifts).
        if self.params['fs'] != 1.0/delt:
            self.params['fs'] = 1.0/delt

        # Generate lisa freq domain data from time domain data
        self.r1, self.r2, self.r3, self.fdata, self.tsegstart, self.tsegmid = self.tser2fser(self.h1, self.h2, self.h3, self.timearray)

        # Characteristic frequency f* = c/(2*pi*L); define f0 = f/(2 f*)
        cspeed = 3e8
        fstar = cspeed/(2*np.pi*self.armlength)
        self.f0 = self.fdata/(2*fstar)

    def make_data_correlation_matrix(self):
        '''
        Uses the generated time-domain data series to construct a data correlation matrix.
        Used to be the initialization of the (now defunct) likelihoods.py

        Sets ``self.rbar``, the three channels stacked on a trailing axis, and
        ``self.rmat`` with ``rmat[i, j, k, l] = conj(rbar[i, j, k]) * rbar[i, j, l]``.
        Axis 0 follows ``fdata``; axis 1 is presumably the time segment -- confirm
        against ``tser2fser``.
        '''
        self.rbar = np.stack((self.r1, self.r2, self.r3), axis=2)

        ## per-(frequency, segment) outer product over channels, computed by broadcasting
        ## instead of a Python double loop; cast to complex128 as the original zeros buffer was
        self.rmat = (np.conj(self.rbar)[..., :, None] * self.rbar[..., None, :]).astype(complex, copy=False)

    def generate_inj_fdata_responses(self):
        '''
        A function to generate the response functions for each injection at the simulated data frequencies.
        This is needed to allow us to accurately plot and compare injected spectra vs. simulated data vs. recovered models.
        '''
        submodels_sgwb = [self.injection.components[cmn] for cmn in self.injection.sgwb_component_names]

        if self.params["faster_geometry"]:
            calculate_response_functions(self.fdata, self.tsegmid, submodels_sgwb, self.params, plot_flag=True)
        else:
            # Characteristic frequency f* = c/(2*pi*L); define f0 = f/(2 f*)
            fstar = 3e8/(2*np.pi*self.armlength)
            f0_data = self.fdata/(2*fstar)
            fast_rx = fast_geometry(self.params)
            fast_rx.calculate_response_functions(f0_data, self.tsegmid, submodels_sgwb, self.params['tdi_lev'], plot_flag=True)
            ## deallocate to save on memory now that the response functions have been calculated and stored elsewhere
            del fast_rx

    def compute_inj_fdata_spectra(self):
        '''
        A function to compute the injected spectra as evaluated at the simulated data frequencies.
        This is needed to allow us to accurately plot and compare injected spectra vs. simulated data vs. recovered models.

        Sets ``fdata`` on every injection component, ``fdata_noise`` on the noise
        component, and ``fdata_spectra`` / ``fdata_convolved_spectra`` on the others.
        '''
        # Characteristic frequency f* = c/(2*pi*L); define f0 = f/(2 f*) (loop-invariant)
        fstar = 3e8/(2*np.pi*self.armlength)
        f0_data = self.fdata/(2*fstar)

        for cmn in self.injection.component_names:
            cm = self.injection.components[cmn]
            cm.fdata = self.fdata

            ## note that for astrophysical signals we take the time-average but keep the full 3x3 correlation matrix
            if cmn == 'noise':
                Np, Na = 10**cm.injvals['log_Np'], 10**cm.injvals['log_Na']
                cm.fdata_noise = cm.instr_noise_spectrum(self.fdata, f0_data, Np, Na)
            elif hasattr(cm, "ispop") and cm.ispop:
                cm.fdata_spectra = cm.population.Sgw_true
                cm.fdata_convolved_spectra = np.mean(cm.fdata_spectra[None, None, :, None] * cm.fdata_response_mat, axis=-1)
            else:
                spec_args = [cm.truevals[parameter] for parameter in cm.spectral_parameters]
                cm.fdata_spectra = cm.compute_Sgw(self.fdata, spec_args)
                cm.fdata_convolved_spectra = np.mean(cm.fdata_spectra[None, None, :, None] * cm.fdata_response_mat, axis=-1)

    def diag_spectra(self):
        '''
        A function to do simple diagnostics. Plot the expected spectra and data.

        Writes noise_multichannel.png, psd_budget.png, diag_psd.png, res_psd.png and
        diag_csd_31.png to ``params['out_dir']``, stores the time-averaged convolved
        spectrum of each SGWB injection component as ``frozen_convolved_spectra``, and
        adds the plotted spectra to ``plot_data.pickle`` under ``'injspec_data'``.

        Raises
        ------
        ValueError
            If ``params['tdi_lev']`` is not 'michelson', 'xyz' or 'aet'.
        '''
        # ------------ Calculate PSD ------------------
        # PSD from the FFTs (averaged over axis 1)
        data_PSD1, data_PSD2, data_PSD3 = np.mean(np.abs(self.r1)**2, axis=1), np.mean(np.abs(self.r2)**2, axis=1), np.mean(np.abs(self.r3)**2, axis=1)

        # "Cut" to desired frequencies; idx is reused for every spectrum evaluated on self.fdata
        idx = np.logical_and(self.fdata >= self.params['fmin'], self.fdata <= self.params['fmax'])
        psdfreqs = self.fdata[idx]
        data_PSD1, data_PSD2, data_PSD3 = data_PSD1[idx], data_PSD2[idx], data_PSD3[idx]

        ## Get the noise component and extract its auto-power
        C_noise = self.injection.components['noise'].fdata_noise
        S1, S2, S3 = C_noise[0, 0, :], C_noise[1, 1, :], C_noise[2, 2, :]

        ## compute and save the time-averaged response-convolved spectrum for plotting, etc.
        for cmn in self.injection.sgwb_component_names:
            cm = self.injection.components[cmn]
            if hasattr(cm, "ispop") and cm.ispop:
                ## the population spectra needs to be binned at the observed (data) frequencies in any case
                cm.frozen_convolved_spectra = cm.fdata_convolved_spectra
            else:
                cm.frozen_convolved_spectra = np.mean(cm.frozen_spectra[None, None, :, None] * cm.inj_response_mat, axis=-1)

        ## load or create plot_data dict
        plot_data_path = self.params['out_dir'] + '/plot_data.pickle'
        if os.path.exists(plot_data_path):
            with open(plot_data_path, 'rb') as datafile:
                plot_data = pickle.load(datafile)
            if 'injspec_data' not in plot_data.keys():
                plot_data['injspec_data'] = {}
        else:
            plot_data = {'injspec_data': {}}

        ## add the full-spectrum simulated data fft to the plotdata dict
        plot_data['injspec_data']['data_psd'] = {'psdfreqs': psdfreqs, 'psd1': data_PSD1, 'psd2': data_PSD2, 'psd3': data_PSD3}
        ## also the noise contributions
        plot_data['injspec_data']['noise_spectra'] = {'noise1': S1, 'noise2': S2, 'noise3': S3}

        plt.close()

        ## channel labels for the configured TDI level (unknown levels used to leave channel_labels unbound)
        tdi_channel_labels = {'michelson': ['1', '2', '3'], 'xyz': ['X', 'Y', 'Z'], 'aet': ['A', 'E', 'T']}
        if self.params['tdi_lev'] not in tdi_channel_labels:
            raise ValueError("Unknown tdi_lev '{}'; expected one of {}.".format(self.params['tdi_lev'], list(tdi_channel_labels)))
        channel_labels = tdi_channel_labels[self.params['tdi_lev']]
        plot_data['injspec_data']['channels'] = channel_labels
        channel_colors = ['black', 'darkgrey', 'steelblue']

        # ------------ noise multichannel plots ------------------
        ## flag if we need to look at CSDs also
        cross = False
        if cross:
            _, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))
            ax_list = [ax1, ax2]
        else:
            plt.figure()
            ax1 = plt.gca()
            ax_list = [ax1]

        ax1.set_title('Instrumental Noise Autopower Spectral Density')
        for i in range(3):
            ax1.loglog(psdfreqs, C_noise[i, i, :][idx], label=channel_labels[i]+channel_labels[i], lw=0.75, color=channel_colors[i])
        ax1.legend()
        if cross:
            ax2.set_title('Instrumental Noise Crosspower Spectral Density')
            for i, j, k in zip([0, 0, 1], [1, 2, 2], [0, 1, 2]):
                ax2.loglog(psdfreqs, C_noise[i, j, :][idx], label=channel_labels[i]+channel_labels[j], lw=0.75, color=channel_colors[k])
            ax2.legend()
        for ax in ax_list:
            ax.set_xlabel('$f$ in Hz')
            ax.set_ylabel('PSD 1/Hz ')
            ax.set_xlim(self.params['fmin'], self.params['fmax'])
        plt.tight_layout()
        plt.savefig(self.params['out_dir'] + '/noise_multichannel.png', dpi=200)
        print('Diagnostic noise spectra plot made in ' + self.params['out_dir'] + '/noise_multichannel.png')
        plt.close()

        # ------------ psd budget plot ------------------
        plt.loglog(psdfreqs, data_PSD1, label='Simulated Data Series PSD', alpha=0.6, lw=0.75, color='slategrey')
        plt.loglog(psdfreqs, C_noise[0, 0, :][idx], label='Instrumental Noise ({} Autopower)'.format(channel_labels[0]+channel_labels[0]), lw=0.75, color='dimgrey')
        ## log10 statistics of each signal spectrum, used to scale the y-axis
        ymins = []
        ywmeds = []
        ydevs = []
        ylim_flag = False
        for component_name in self.injection.sgwb_component_names:
            S1_gw = self.injection.plot_injected_spectra(component_name, fs_new='data', convolved=True, legend=True, channels='11', return_PSD=True, lw=0.75, color=self.injection.components[component_name].color)
            S1_gw_filt = S1_gw[idx]
            plot_data['injspec_data'][component_name] = {'S1_gw': S1_gw_filt, 'f_filt': psdfreqs}
            ## only the nonzero part of the spectrum can be log-scaled (spectra may have cutoffs)
            nonzero = np.nonzero(S1_gw_filt)
            log_S1_gw = np.log10(S1_gw_filt[nonzero])
            log_bin_sizes = np.diff(np.log10(psdfreqs[nonzero]))
            ave_bin_val = (log_S1_gw[1:] + log_S1_gw[:-1])/2
            ## log-frequency-weighted mean of log10(S)
            weighted_ymed = np.sum(log_bin_sizes*ave_bin_val/np.sum(log_bin_sizes))
            ymins.append(np.min(log_S1_gw))
            ywmeds.append(weighted_ymed)
            ydevs.append(np.std(log_S1_gw))
            ylim_flag = True
            S2_gw = self.injection.compute_convolved_spectra(component_name, fs_new='data', channels='22')
            S3_gw = self.injection.compute_convolved_spectra(component_name, fs_new='data', channels='33')
            S1, S2, S3 = S1 + S1_gw, S2 + S2_gw, S3 + S3_gw

        plt.loglog(psdfreqs, S1[idx], label='Simulated Total spectrum', lw=0.75, color='cadetblue')
        ## set plot limits, with dynamic scaling for the y-axis to handle spectra with cutoffs, etc.
        if ylim_flag:
            ylows = [ywmed_i - ydev_i for ywmed_i, ydev_i in zip(ywmeds, ydevs)]
            ylow_min = np.min(ylows)
            plt.ylim(bottom=10**(ylow_min-1))
            ## save for plotting the results with plotmaker
            self.injection.plot_ymin = 10**(ylow_min-1)
        plt.legend(loc='upper right')
        plt.xlabel('$f$ in Hz')
        plt.ylabel('PSD 1/Hz ')
        plt.xlim(self.params['fmin'], self.params['fmax'])
        plt.savefig(self.params['out_dir'] + '/psd_budget.png', dpi=200)
        print('Diagnostic spectra plot made in ' + self.params['out_dir'] + '/psd_budget.png')
        plt.close()

        # ------------ channel-3 PSD vs. expectation ------------------
        plt.loglog(self.fdata, S3, label='required', color='mediumvioletred')
        plt.loglog(psdfreqs, data_PSD3, label='PSD, data', alpha=0.6, color='slategrey')
        plt.xlabel('$f$ in Hz')
        plt.ylabel('PSD 1/Hz ')
        plt.legend()
        plt.grid(linestyle=':', linewidth=0.5)
        plt.xlim(self.params['fmin'], self.params['fmax'])
        ## avoid plot squishing due to signal spectra with cutoffs, etc.
        ## ymins holds log10 values, so compare against log10(1e-43) = -43
        if len(ymins) > 0 and np.min(ymins) < -43:
            plt.ylim(bottom=1e-43)
        plt.savefig(self.params['out_dir'] + '/diag_psd.png', dpi=200)
        print('Diagnostic spectra plot made in ' + self.params['out_dir'] + '/diag_psd.png')
        plt.close()

        ## lets also plot psd residue (both terms restricted to [fmin, fmax] so the shapes agree)
        S3_filt = S3[idx]
        rel_res_mean = (data_PSD3 - S3_filt)/S3_filt
        plt.semilogx(psdfreqs, rel_res_mean, label='relative mean residue', color='slategrey')
        plt.xlabel('f in Hz')
        plt.ylabel(' Rel. residue')
        plt.ylim([-1.50, 1.50])
        plt.legend()
        plt.grid()
        plt.xlim(self.params['fmin'], self.params['fmax'])
        plt.savefig(self.params['out_dir'] + '/res_psd.png', dpi=200)
        print('Residue spectra plot made in ' + self.params['out_dir'] + '/res_psd.png')
        plt.close()

        # ------------ cross-power diag plot ------------------
        # Only one channel pair is plotted: (ii, jj) = (2, 0), i.e. "31".
        # If TDI=XYZ this is S_ZX and if TDI=AET this is S_TA.
        ii, jj = 2, 0
        IJ = str(ii+1) + str(jj+1)
        Sx = C_noise[ii, jj, :]
        for component_name in self.injection.sgwb_component_names:
            if component_name != 'noise':
                Sx_gw = self.injection.compute_convolved_spectra(component_name, fs_new='data', channels=IJ) + self.injection.compute_convolved_spectra(component_name, fs_new='data', channels=IJ, imaginary=True)
                Sx = Sx + Sx_gw

        ## data cross-spectral density, restricted to [fmin, fmax] to match psdfreqs
        CSDx = np.mean(np.conj(self.rbar[:, :, ii]) * self.rbar[:, :, jj], axis=1)[idx]

        plt.subplot(2, 1, 1)
        if len(Sx.shape) == 1:
            plt.loglog(self.fdata, np.abs(np.real(Sx)), label='Re(Required ' + IJ + ')', color='mediumvioletred')
        else:
            plt.loglog(self.fdata, np.mean(np.abs(np.real(Sx)), axis=1), label='Re(Required ' + IJ + ')', color='mediumvioletred')
        plt.loglog(psdfreqs, np.abs(np.real(CSDx)), label='Re(CSD' + IJ + ')', alpha=0.6, color='slategrey')
        plt.xlabel('f in Hz')
        plt.ylabel('Power in 1/Hz')
        plt.legend()
        plt.ylim([1e-44, 5e-40])
        plt.xlim(self.params['fmin'], self.params['fmax'])
        plt.grid()

        plt.subplot(2, 1, 2)
        if len(Sx.shape) == 1:
            plt.loglog(self.fdata, np.abs(np.imag(Sx)), label='Im(Required ' + IJ + ')', color='mediumvioletred')
        else:
            plt.loglog(self.fdata, np.mean(np.abs(np.imag(Sx)), axis=1), label='Im(Required ' + IJ + ')', color='mediumvioletred')
        plt.loglog(psdfreqs, np.abs(np.imag(CSDx)), label='Im(CSD' + IJ + ')', alpha=0.6, color='slategrey')
        plt.xlabel('f in Hz')
        plt.ylabel(' Power in 1/Hz')
        plt.legend()
        plt.xlim(self.params['fmin'], self.params['fmax'])
        plt.ylim([1e-44, 5e-40])
        plt.grid()
        plt.savefig(self.params['out_dir'] + '/diag_csd_' + IJ + '.png', dpi=200)
        print('Diagnostic spectra plot made in ' + self.params['out_dir'] + '/diag_csd_' + IJ + '.png')
        plt.close()

        ## save fit data via a temp file + move, so an interrupted write cannot truncate an existing pickle
        temp_file = plot_data_path + ".temp"
        with open(temp_file, "wb") as datafile:
            pickle.dump(plot_data, datafile)
        shutil.move(temp_file, plot_data_path)

    def plot_spectra(self):
        '''
        A function to make a plot of the data spectrum. For use with external (non-autogenerated) data,
        where we cannot calculate the intrinsic components.

        Writes data_psd.png (the three channel PSDs within [fmin, fmax]) to ``params['out_dir']``.
        '''
        # PSD from the FFTs (averaged over axis 1)
        data_PSD1, data_PSD2, data_PSD3 = np.mean(np.abs(self.r1)**2, axis=1), np.mean(np.abs(self.r2)**2, axis=1), np.mean(np.abs(self.r3)**2, axis=1)

        # "Cut" to desired frequencies
        idx = np.logical_and(self.fdata >= self.params['fmin'], self.fdata <= self.params['fmax'])
        psdfreqs = self.fdata[idx]
        data_PSD1, data_PSD2, data_PSD3 = data_PSD1[idx], data_PSD2[idx], data_PSD3[idx]

        plt.loglog(psdfreqs, data_PSD1, label='PSD (1)', alpha=0.6, color='slategrey')
        plt.loglog(psdfreqs, data_PSD2, label='PSD (2)', alpha=0.6, color='rosybrown')
        plt.loglog(psdfreqs, data_PSD3, label='PSD (3)', alpha=0.6, color='mediumseagreen')
        plt.xlabel('$f$ in Hz')
        plt.ylabel('PSD 1/Hz ')
        plt.legend()
        plt.grid(linestyle=':', linewidth=0.5)
        plt.xlim(self.params['fmin'], self.params['fmax'])
        plt.savefig(self.params['out_dir'] + '/data_psd.png', dpi=200)
        print('Data spectra plot made in ' + self.params['out_dir'] + '/data_psd.png')
        plt.close()
def run_pipeline(parsed_params, resume, pre_sample_hook=None):
    """Run the Bayesian pipeline.

    Parameters
    ----------
    parsed_params : tuple
        Parameters for the run as output by :func:`parse_config`.
    resume : bool
        Whether to resume a checkpointed run.
    pre_sample_hook : Callable, optional
        Function to call on the analysis model before sampling. Can be used to perform
        arbitrary changes to the model. By default None

    Raises
    ------
    ValueError
        If ``pre_sample_hook`` is given together with ``resume``, if an emcee run is
        resumed, or if a finished numpyro run is resumed without ``additional_samples``.
    TypeError
        If ``params['sampler']`` is not one of dynesty, emcee or numpyro.
    """
    params, inj, misc = parsed_params
    nthread = misc["nthread"]
    randst = misc["randst"]
    nlive = misc["nlive"]
    N_GPU = misc["N_GPU"]

    ## on resume the in-memory LISA object is never built, so there is nothing to hand the hook
    if resume and pre_sample_hook is not None:
        raise ValueError("pre_sample_hook cannot be applied when resuming a run: the model and sampler state are reloaded from disk.")

    if not resume:
        # Make output folder
        os.makedirs(params['out_dir'], exist_ok=True)
        # Copy the params file to outdir, to keep track of the parameters of each run.
        path_paramsfile = os.path.join(params['out_dir'], misc['paramsfile_name'])
        with open(path_paramsfile, "w") as f:
            f.write(misc['paramsfile_text'])

        # Initialize lisa class
        lisa = LISA(params, inj)

        if not params['load_data']:
            _save_injection_and_data(lisa, params)

        ## Data generation is complete. Exit if only performing injection and data simulation.
        if inj['inj_only']:
            print("Simulation and generation of LISA data is complete and inj_only flag is set to True. Saving configuration and exiting...")
            _save_config(params, inj)
            return

        ## the Model tends to be more lightweight
        with open(params['out_dir'] + '/model.pickle', 'wb') as outfile:
            pickle.dump(lisa.Model, outfile)
        print("Saved Model to "+params['out_dir']+"/model.pickle")

        # Also save the config so we can load and use our simulated data if the run fails during sampling
        _save_config(params, inj)
        print("Generating sampling engine...")

        # run user code to modify the model before sampling
        if pre_sample_hook is not None:
            print("Pre-sample hook was provided. Running hook...")
            lisa.Model = pre_sample_hook(lisa.Model, parsed_params)
            print("Done running pre-sample hook.")
    else:
        print("Resuming a previous analysis. Reloading data and sampler state...")
        ## the sampler-specific code reloads what it needs from out_dir
        lisa = None

    if params['sampler'] == 'dynesty':
        post_samples, parameters = _run_dynesty(lisa, params, nlive, nthread, randst, resume)
    elif params['sampler'] == 'emcee':
        post_samples, parameters = _run_emcee(lisa, params, nlive, nthread, randst, resume)
    elif params['sampler'] == 'numpyro':
        post_samples, parameters = _run_numpyro(lisa, params, nthread, N_GPU, resume)
    # NOTE: a 'numpyro_nested' sampler is not supported yet.
    else:
        raise TypeError('Unknown sampler model chosen. Only dynesty, numpyro, & emcee are supported')

    ## Safely re-save the config pickle, now with the model parameters
    _save_config(params, inj, parameters)

    _make_posterior_plots(post_samples, params, inj, parameters, resume, lisa)


def _save_config(params, inj, *extra):
    """Write params, inj and any extra objects (in order) to out_dir/config.pickle via a temp file."""
    config_path = params['out_dir'] + '/config.pickle'
    temp_file = config_path + ".temp"
    with open(temp_file, "wb") as outfile:
        pickle.dump(params, outfile)
        pickle.dump(inj, outfile)
        for obj in extra:
            pickle.dump(obj, outfile)
    shutil.move(temp_file, config_path)


def _save_injection_and_data(lisa, params):
    """Save the Injection (minus bulky response matrices), injected skymaps and simulated strain data."""
    ## the Injection is massive, so discard the responses we no longer need
    ## saving & discarding now, as opposed to at the end of the run also saves space in the checkpoint files.
    for cmn in lisa.injection.component_names:
        component = lisa.injection.components[cmn]
        for attr in ('response_mat', 'summ_response_mat', 'inj_response_mat'):
            if hasattr(component, attr):
                delattr(component, attr)
    with open(params['out_dir'] + '/injection.pickle', 'wb') as outfile:
        pickle.dump(lisa.injection, outfile)
    print("Saved Injection to "+params['out_dir']+"/injection.pickle")

    ## also save the injected skymaps for later use
    lisa.injection.extract_and_save_skymap_data(map_data_path=params['out_dir']+"/plot_data.pickle")

    ## save generated data
    np.savez_compressed(params['out_dir']+'/simulated_data.npz', timearray=lisa.timearray, h1=lisa.h1, h2=lisa.h2, h3=lisa.h3)
    print("Saved strain time series to "+params['out_dir']+"/simulated_data.npz")


def _run_dynesty(lisa, params, nlive, nthread, randst, resume):
    """Build (or reload) and run the dynesty engine, save its outputs and return (post_samples, parameters)."""
    from blip.src.dynesty_engine import dynesty_engine

    # Create engine
    if not resume:
        # multiprocessing
        pool = Pool(nthread) if nthread > 1 else None
        engine, parameters = dynesty_engine().define_engine(lisa, params, nlive, nthread, randst, pool=pool)
    else:
        pool = None
        if nthread > 1:
            print("Warning: Nthread > 1, but multiprocessing is not supported when resuming a run. Pool set to None.")
            ## To anyone reading this and wondering why:
            ## The pickle calls used by Python's multiprocessing fail when trying to run the sampler after saving/reloading it.
            ## This is because pickling the sampler maps all its attributes to their full paths;
            ## e.g., dynesty_engine.isgwb_prior is named as src.dynesty_engine.dynesty_engine.isgwb_prior
            ## BUT the object itself is still e.g. <function dynesty_engine.isgwb_prior at 0x7f8ebcc27130>
            ## so we get an error like
            ## _pickle.PicklingError: Can't pickle <function dynesty_engine.isgwb_prior at 0x7f8ebcc27130>: \
            ##      it's not the same object as src.dynesty_engine.dynesty_engine.isgwb_prior
            ## See e.g. https://stackoverflow.com/questions/1412787/picklingerror-cant-pickle-class-decimal-decimal-its-not-the-same-object
            ## After too much time and sanity spent trying to fix this, I have admitted defeat.
            ## Feel free to try your hand -- maybe you're the chosen one. Good luck.
        engine, parameters = dynesty_engine.load_engine(params, randst, pool)

    ## run sampler
    t1 = time.time()
    if params['checkpoint']:
        checkpoint_file = params['out_dir']+'/checkpoint.pickle'
        post_samples, logz, logzerr = dynesty_engine.run_engine_with_checkpointing(engine, parameters, params['checkpoint_interval'], checkpoint_file, step=200)
    else:
        post_samples, logz, logzerr = dynesty_engine.run_engine(engine)
    t2 = time.time()
    print("Elapsed time to converge: {} s".format(t2-t1))

    ## close the pool we created (on resume there is none, even if nthread > 1)
    if pool is not None:
        pool.close()
        pool.join()

    # Save posteriors to file
    np.savetxt(params['out_dir'] + "/post_samples.txt", post_samples)
    np.savetxt(params['out_dir'] + "/logz.txt", logz)
    np.savetxt(params['out_dir'] + "/logzerr.txt", logzerr)
    return post_samples, parameters


def _run_emcee(lisa, params, nlive, nthread, randst, resume):
    """Build and run the emcee engine, save its outputs and return (post_samples, parameters)."""
    if resume:
        raise ValueError("Resuming a run is not supported for the emcee sampler.")
    from blip.src.emcee_engine import emcee_engine

    # multiprocessing
    pool = Pool(nthread) if nthread > 1 else None

    # Create engine
    engine, parameters, init_samples = emcee_engine.define_engine(lisa.Model, nlive, randst, pool=pool)
    unit_samples, post_samples = emcee_engine.run_engine(engine, lisa.Model, init_samples, params['Nburn'], params['Nsamples'])

    if pool is not None:
        pool.close()
        pool.join()

    # Save posteriors to file
    np.savetxt(params['out_dir'] + "/unit_samples.txt", unit_samples)
    np.savetxt(params['out_dir'] + "/post_samples.txt", post_samples)
    return post_samples, parameters


def _run_numpyro(lisa, params, nthread, N_GPU, resume):
    """Build (or reload) and run the numpyro engine, save its chain and return (post_samples, parameters)."""
    if nthread > 1:
        ## must be set before the numpyro engine (and hence the XLA backend) is set up
        os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count={}'.format(nthread)
    from blip.src.numpyro_engine import numpyro_engine

    ## create engine
    Ntotal = params['Nsamples']
    if not resume:
        ## without checkpointing, set up the engine to run for Nsamples samples
        ## with checkpointing, set it up to run for checkpoint_interval samples (and set the starting chain to None)
        if params['checkpoint'] and params['checkpoint_at'] == 'interval':
            Nrun = params['checkpoint_interval']
        else:
            Nrun = Ntotal
        ## set initial chain to None
        chain = None
        use_gpu = N_GPU > 0
        engine, parameters, rng_key = numpyro_engine.define_engine(lisa, params['Nburn'], Nrun, nthread, params['show_progress'], params['seed'], gpu=use_gpu)
        lisaModel = lisa.Model
    else:
        resume_file = params['out_dir']+'/checkpoint.pickle'
        engine, rng_key, chain = numpyro_engine.load_engine(resume_file)
        with open(params['out_dir'] + '/model.pickle', 'rb') as modfile:
            lisaModel = pickle.load(modfile)
        parameters = lisaModel.parameters

    ## remove the functions which calculate responses so these aren't traced by JAX
    if hasattr(lisaModel, "calculate_response_functions"):
        delattr(lisaModel, "calculate_response_functions")

    ## check to see whether the run has already finished
    N_current = 0 if chain is None else len(chain['theta_transformed'][0])

    ## additional samples
    if params['additional_samples'] is not None:
        if N_current < params['Nsamples']:
            print("WARNING: {} additional samples were requested, but the original run isn't finished ({}/{} samples remaining)".format(params['additional_samples'], params['Nsamples']-N_current, params['Nsamples']))
            print("WARNING: Disregarding and continuing the original run...")
        else:
            print("Run was previously completed and {} additional samples have been requested. Resuming chain sampling from final state...".format(params['additional_samples']))
            ## new total = old total + additional
            Ntotal = N_current + params['additional_samples']
            ## if not sampling in N=checkpoint_interval chunks, set the requested # of samples
            if params['checkpoint_at'] != 'interval':
                engine.num_samples = params['additional_samples']
    ## if we're not doing additional samples and the run is over, raise a helpful error.
    elif N_current >= params['Nsamples']:
        raise ValueError("Run was resumed but has already finished. If you want to continue adding samples to the chain, specify additional_samples=[N desired] in the run directory params file before resuming.")

    ## run sampler
    t1 = time.time()
    if params['checkpoint']:
        checkpoint_file = params['out_dir']+'/checkpoint.pickle'
        post_samples = numpyro_engine.run_engine_with_checkpointing(engine, lisaModel, rng_key, chain, checkpoint_file, Ntotal, params['checkpoint_at'], resume=resume)
    else:
        post_samples = numpyro_engine.run_engine(engine, lisa, rng_key)
    t2 = time.time()
    print("Elapsed time to converge: {} s".format(t2-t1))

    ## save chain
    np.savetxt(params['out_dir'] + "/post_samples.txt", post_samples)
    return post_samples, parameters


def _make_posterior_plots(post_samples, params, inj, parameters, resume, lisa):
    """Make the corner, fit and (when any submodel has a map) sky-map plots of the posterior."""
    print("\nMaking posterior Plots ...")

    ## reload the Model and Injection if needed; an Injection only exists for simulated data
    plotting_Injection = None
    if not resume:
        plotting_Model = lisa.Model
        if not params['load_data']:
            plotting_Injection = lisa.injection
    else:
        with open(params['out_dir'] + '/model.pickle', 'rb') as modelfile:
            plotting_Model = pickle.load(modelfile)
        if not params['load_data']:
            with open(params['out_dir'] + '/injection.pickle', 'rb') as injectionfile:
                plotting_Injection = pickle.load(injectionfile)

    ## reset the matplotlib style setting
    matplotlib.rcParams.update(matplotlib.rcParamsDefault)
    if not params['load_data']:
        cornermaker(post_samples, params, parameters, inj, plotting_Model, Injection=plotting_Injection)
    else:
        cornermaker(post_samples, params, parameters, inj, plotting_Model)
    if plotting_Model.Npar >= 10:
        print("\n")
        print("WARNING: High Model N_parameters detected ({}). Corner plot may be crowded. Try running plotmaker.py and specifying the 'cornersplit' argument via".format(plotting_Model.Npar))
        print("python3 ./blip/tools/plotmaker.py {} --cornersplit type".format(params['out_dir']))
        print("or")
        print("python3 ./blip/tools/plotmaker.py {} --cornersplit submodel".format(params['out_dir']))
        print("\n")

    ## reset the matplotlib style setting
    matplotlib.rcParams.update(matplotlib.rcParamsDefault)
    if not params['load_data']:
        fitmaker(post_samples, params, parameters, inj, plotting_Model, plotting_Injection)
    else:
        fitmaker(post_samples, params, parameters, inj, plotting_Model)

    ## reset the matplotlib style setting
    matplotlib.rcParams.update(matplotlib.rcParamsDefault)
    ## make a map if there is a map to be made
    if any(plotting_Model.submodels[sm_name].has_map for sm_name in plotting_Model.submodel_names):
        if 'healpy_proj' in params.keys():
            mapmaker(post_samples, params, parameters, plotting_Model, coord=params['healpy_proj'], cmap=params['colormap'])
        else:
            mapmaker(post_samples, params, parameters, plotting_Model, cmap=params['colormap'])
def blip(paramsfile, *, resume):
    """Run BLIP on a given parameter file with given command-line options.

    The file is parsed with :func:`parse_config` and the result is handed
    straight to :func:`run_pipeline`.

    Parameters
    ----------
    paramsfile : str
        Path to INI parameter file.
    resume : bool
        CLI option to resume from a checkpointed run.
    """
    run_pipeline(parse_config(paramsfile, resume), resume=resume)
def main():
    """Command-line entry point: ``<script> PARAMSFILE [resume]``.

    Raises
    ------
    ValueError
        If the arguments are not exactly a params file, optionally followed by
        the literal word ``resume``.
    """
    ## previously a missing params file raised IndexError and extra trailing arguments were ignored
    args = sys.argv[1:]
    if len(args) == 1:
        blip(args[0], resume=False)
    elif len(args) == 2 and args[1] == "resume":
        blip(args[0], resume=True)
    else:
        raise ValueError("Provide (only) the params file as an argument, optionally followed by 'resume'")


if __name__ == "__main__":
    main()