Source code for blip.src.dynesty_engine

import numpy as np
from dynesty import NestedSampler
from dynesty.utils import resample_equal
import dill
import time
import shutil, os

class dynesty_engine():

    '''
    Class for interfacing with the dynesty sampler.

    Groups the helpers used to build a dynesty NestedSampler for a model,
    restore one from a checkpoint file, and run it (with or without periodic
    checkpointing). All helpers are class-level callables; no instance state
    is used.
    '''

    @classmethod
    def define_engine(cls, lisaobj, params, nlive, nthread, randst, pool=None, resume=False):

        '''
        Build a dynesty NestedSampler for the model attached to lisaobj.

        Parameters
        ----------
        lisaobj : object
            Must expose Model.likelihood, Model.prior, Model.Npar and
            Model.parameters.
        params : dict
            Configuration dictionary; only params['sample_method'] is read here.
        nlive : int
            Number of live points handed to the sampler.
        nthread : int
            Number of threads. When greater than 1 it is used as the sampler
            queue size; otherwise any supplied pool is discarded.
        randst : object
            Random state forwarded to the sampler as rstate.
        pool : object, optional
            Pool forwarded to the sampler. Ignored (with a warning) when
            nthread is not greater than 1.
        resume : bool, optional
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        engine : NestedSampler
            The configured sampler.
        parameters : object
            lisaobj.Model.parameters, passed through unchanged.
        '''

        # Decide the queue size; the pool itself is created by the caller.
        if nthread > 1:
            pool_size = nthread
        else:
            if pool is not None:
                print("Warning: Nthread=1 but pool has been defined. This shouldn't happen...")
            pool = None
            pool_size = None

        engine = NestedSampler(lisaobj.Model.likelihood, lisaobj.Model.prior, lisaobj.Model.Npar,
                               bound='multi', sample=params['sample_method'], nlive=nlive,
                               pool=pool, queue_size=pool_size, rstate=randst)

        # print npar
        print("Npar = " + str(lisaobj.Model.Npar))

        return engine, lisaobj.Model.parameters

    @staticmethod
    def load_engine(params, randst, pool):

        '''
        Load the sampler and parameters saved at params['out_dir']/checkpoint.pickle.

        The restored engine has its queue counter set to a negative number so
        that the queue is refilled before the first iteration, and has its
        random state and pool re-attached because neither is carried by the
        checkpoint.

        Parameters
        ----------
        params : dict
            Configuration dictionary; only params['out_dir'] is read here.
        randst : object
            Random state to install on the restored engine.
        pool : object or None
            Pool to attach to the restored engine; None selects the serial
            built-in map.

        Returns
        -------
        engine, parameters
            The two objects stored in the checkpoint file.

        Raises
        ------
        TypeError
            If the checkpoint file does not exist. (The exception type is kept
            as-is so that existing callers catching it keep working.)
        '''

        resume_file = params['out_dir'] + '/checkpoint.pickle'

        if not os.path.isfile(resume_file):
            raise TypeError("Checkpoint file <{}> does not exist. Cannot resume from checkpoint.".format(resume_file))

        print("Loading interrupted analysis from last checkpoint...")
        # NOTE: unpickling executes arbitrary code; only load checkpoints
        # written by a trusted run.
        with open(resume_file, 'rb') as file:
            engine, parameters = dill.load(file)

        # Drop live points that were merged into the dead points when the
        # checkpointed sampling call ended.
        if engine.added_live:
            engine._remove_live_points()

        # A negative nqueue triggers a queue refill before the first iteration.
        engine.nqueue = -1
        # The random state is not saved in the checkpoint, so set it again.
        engine.rstate = randst

        if pool is not None:
            engine.pool = pool
            engine.loglikelihood.pool = pool
            engine.M = engine.pool.map
        else:
            engine.pool = None
            engine.loglikelihood.pool = None
            engine.M = map

        return engine, parameters

    @staticmethod
    def _collect_results(engine):

        '''
        Turn engine.results into equally weighted posterior samples.

        Returns
        -------
        post_samples
            Output of resample_equal applied to the samples and their weights.
        logz, logzerr
            The 'logz' and 'logzerr' entries of the results, unmodified.
        '''

        res = engine.results

        # Importance weights relative to the final evidence estimate.
        weights = np.exp(res['logwt'] - res['logz'][-1])
        # Overwrite the last weight so that the weights sum to exactly one.
        weights[-1] = 1 - np.sum(weights[0:-1])

        post_samples = resample_equal(res.samples, weights)

        # Pull the evidence and the evidence error
        logz = res['logz']
        logzerr = res['logzerr']

        return post_samples, logz, logzerr

    @staticmethod
    def run_engine_with_checkpointing(engine, parameters, interval, checkpoint_file, step=1000):

        '''
        Run the sampler in chunks of step iterations, checkpointing periodically.

        Sampling stops once a call to run_nested no longer increases
        engine.ncall. A checkpoint is written whenever more than interval
        seconds have passed since checkpoint_file was last modified (or since
        this call started, if the file does not exist yet).

        Parameters
        ----------
        engine : object
            Sampler exposing run_nested, ncall, pool, M, added_live,
            _remove_live_points and results.
        parameters : object
            Saved alongside the engine in every checkpoint.
        interval : float
            Minimum number of seconds between checkpoints.
        checkpoint_file : str
            Path of the checkpoint file.
        step : int, optional
            Maximum number of iterations per run_nested call.

        Returns
        -------
        post_samples, logz, logzerr
            See _collect_results.
        '''

        # -------------------- Run nested sampler ---------------------------
        pool = engine.pool
        old_ncall = engine.ncall
        start_time = time.time()

        while True:
            engine.run_nested(dlogz=0.5, maxiter=step, print_progress=True)

            # No new likelihood calls means the sampler has finished.
            if engine.ncall == old_ncall:
                break
            old_ncall = engine.ncall

            if os.path.isfile(checkpoint_file):
                last_checkpoint_s = time.time() - os.path.getmtime(checkpoint_file)
            else:
                last_checkpoint_s = time.time() - start_time

            if last_checkpoint_s > interval:
                print("Checkpointing...")

                ## pause the pool for saving
                if engine.pool is not None:
                    engine.pool = None
                    engine.M = map

                try:
                    ## save
                    if dill.pickles([engine, parameters]):
                        # Write to a temporary file first so that an
                        # interrupted write does not clobber the previous
                        # checkpoint.
                        temp_file = checkpoint_file + ".temp"
                        with open(temp_file, "wb") as file:
                            dill.dump([engine, parameters], file)
                        shutil.move(temp_file, checkpoint_file)
                    else:
                        print("Warning: Cannot write checkpoint file, job cannot resume if interrupted.")
                finally:
                    ## restart pool if needed, even when saving raised
                    if pool is not None:
                        engine.pool = pool
                        engine.M = engine.pool.map

            ## removes live points lumped in with dead points when previous sampling call concluded
            ## unsure why dynesty requires this, but it seems to be deeply important
            if engine.added_live:
                engine._remove_live_points()

        return dynesty_engine._collect_results(engine)

    @staticmethod
    def run_engine(engine):

        '''
        Run the sampler to completion in a single run_nested call.

        Returns
        -------
        post_samples, logz, logzerr
            See _collect_results.
        '''

        # -------------------- Run nested sampler ---------------------------
        engine.run_nested(dlogz=0.5, print_progress=True)

        return dynesty_engine._collect_results(engine)