Source code for blip.src.numpyro_engine

import numpy as np
import jax.numpy as jnp
import jax
import numpyro
from numpyro.infer import MCMC, NUTS
#from numpyro.contrib.nested_sampling import NestedSampler
from numpyro.distributions import constraints
import numpyro.distributions as dist
import pickle, dill
import os
import shutil


def numpyro_model(Model):
    '''
    Numpyro model function wrapping the unified prior and likelihood of a Model object.

    Every parameter is drawn on the unit interval, pushed through the Model's
    prior transform, and the resulting log likelihood is registered with
    Numpyro as a factor.

    Arguments
    -------------
    Model (object) : Must expose
                       - Npar : number of parameters (size of the "Npar" plate)
                       - prior(theta) : transform applied to the unit-interval draws
                       - likelihood(theta_transformed) : log likelihood of the transformed draws
    '''

    ## one Uniform(0,1) draw per model parameter
    with numpyro.plate("Npar", Model.Npar):
        unit_theta = numpyro.sample('theta', dist.Uniform(0, 1))

    ## record the transformed parameters so they are returned alongside the raw samples
    transformed = numpyro.deterministic("theta_transformed", Model.prior(unit_theta))

    ## add the log likelihood to the joint log density
    loglike = Model.likelihood(transformed)
    numpyro.factor('loglike', log_factor=loglike)
def numpyro_model_sph(Model):
    '''
    Numpyro model function wrapping the unified prior and likelihood of a Model object.

    The _sph version accounts for the fact that the phase parameter prior should
    be periodic: parameters listed in Model.blm_phase_idx are drawn from an
    improper uniform over the reals, all other parameters from Uniform(0,1).
    The two sets are then interleaved into one vector before the prior
    transform and likelihood are applied.

    Arguments
    -------------
    Model (object) : Must expose
                       - Npar : total number of parameters
                       - blm_phase_idx : positions of the phase parameters in the full vector
                       - prior(theta) : transform applied to the assembled vector
                       - likelihood(theta_transformed) : log likelihood of the transformed draws
    '''

    phase_positions = Model.blm_phase_idx
    n_phase = len(phase_positions)
    n_standard = Model.Npar - n_phase

    ## non-phase parameters live on the unit interval
    with numpyro.plate("Npar_nc", n_standard):
        theta_nc = numpyro.sample('theta_nc', dist.Uniform(0, 1))

    ## phase parameters are unbounded
    with numpyro.plate("Npar_c", n_phase):
        theta_c = numpyro.sample('theta_c', dist.ImproperUniform(constraints.real, (), ()))

    ## slot each phase draw into the running vector at its listed position;
    ## insertions are applied one after another, in the order given by blm_phase_idx
    theta_tot = theta_nc
    for slot, position in enumerate(phase_positions):
        theta_tot = jnp.insert(theta_tot, position, theta_c[slot])

    theta = numpyro.deterministic("theta", theta_tot)
    transformed = numpyro.deterministic("theta_transformed", Model.prior(theta))

    ## add the log likelihood to the joint log density
    numpyro.factor('loglike', log_factor=Model.likelihood(transformed))
class numpyro_engine():

    '''
    Class for interfacing with the Numpyro HMC (NUTS) sampler.
    '''

    @classmethod
    def define_engine(cls, lisaobj, Nburn, Nsamples, Nthreads, prog, seed, gpu=False):
        '''
        Build the Numpyro MCMC engine (NUTS kernel) for a given analysis.

        Arguments
        -------------------
        lisaobj (object) : Object carrying the Model (lisaobj.Model) to be sampled. The Model's
                           blm_phase_idx, Npar and parameters attributes are read here.
        Nburn (int)      : Number of warmup steps (passed as num_warmup).
        Nsamples (int)   : Number of samples per run (passed as num_samples).
        Nthreads (int)   : Number of chains (passed as num_chains). May be reduced to the
                           number of GPUs when gpu=True and several GPUs are present.
        prog (bool)      : Whether to show the progress bar (passed as progress_bar).
        seed (int)       : Seed for the JAX PRNG key. Required.
        gpu (bool)       : Whether GPU sampling was requested.

        Returns
        -------------------
        engine (numpyro.infer.MCMC) : The configured sampler.
        parameters                  : lisaobj.Model.parameters, passed through unchanged.
        rng_key                     : jax.random.PRNGKey(seed).

        Raises
        -------------------
        ValueError : If gpu=True but no GPU is available.
        TypeError  : If seed is None.
        '''

        ## multithreading setup
        ## default to parallel chains
        chain_method = 'parallel'
        if gpu:
            ## check number of available gpus
            ## jax raises RuntimeError when the gpu backend is absent/uninitializable; treat that
            ## as "zero GPUs" so the explicit error below is what the user sees
            gpu_error = None
            try:
                N_GPU = jax.local_device_count(backend='gpu')
            except RuntimeError as err:
                N_GPU = 0
                gpu_error = err

            if N_GPU == 1:
                if Nthreads > 1:
                    print("Nthreads = {}, but only one GPU is available. Setting numpyro chain_method to 'vectorized'.".format(Nthreads))
                    print(" WARNING: Vectorized GPU sampling is an experimental feature and is not stable for all BLIP configurations. If you get an XLA GEMM error, this is probably the cause; revert to standard parallelization in such cases.")
                    chain_method = 'vectorized'
            elif N_GPU > 1:
                if Nthreads > N_GPU:
                    print("Nthreads ({}) > N_GPU ({}) but vectorized parallel sampling has not yet been implemented. Setting Nthreads = N_GPU.".format(Nthreads,N_GPU))
                    Nthreads = N_GPU
            else:
                raise ValueError("GPU usage was requested but no GPUs are available!") from gpu_error
        else:
            numpyro.set_host_device_count(Nthreads)

        if seed is None:
            raise TypeError("Numpyro sampler requires a defined seed.")
        rng_key = jax.random.PRNGKey(seed)

        ## if there are phase parameters, use the sph wrapper; otherwise use the standard one
        if len(lisaobj.Model.blm_phase_idx) > 0:
            kernel = NUTS(numpyro_model_sph)
        else:
            kernel = NUTS(numpyro_model)

        engine = MCMC(kernel,
                      num_warmup=Nburn,
                      num_samples=Nsamples,
                      num_chains=Nthreads,
                      chain_method=chain_method,
                      progress_bar=prog)

        # print npar
        print("Npar = " + str(lisaobj.Model.Npar))

        return engine, lisaobj.Model.parameters, rng_key

    @staticmethod
    def run_engine(engine,lisaobj,rng_key):
        '''
        Run the sampler from start to finish with no checkpointing.

        Arguments
        -------------------
        engine           : Sampler exposing run(rng_key, Model) and get_samples().
        lisaobj (object) : Object whose Model attribute is handed to the sampler.
        rng_key          : PRNG key used for the run.

        Returns
        -------------------
        post_samples (array) : Transpose of the 'theta_transformed' samples.
        '''

        # -------------------- Run HMC sampler ---------------------------
        print("Beginning sampling...")
        engine.run(rng_key,lisaobj.Model)
        print("Sampling complete. Retrieving posterior and plotting results...")

        ## retrieve samples and reformat
        post_samples = np.array(engine.get_samples()['theta_transformed']).T

        return post_samples

    @staticmethod
    def load_engine(resume_file):
        '''
        Restore the sampler, its state, and the accumulated chain from a checkpoint file.

        Arguments
        -------------------
        resume_file (str) : Path to a checkpoint holding a pickled [engine, state, chain] triple.

        Returns
        -------------------
        engine  : The unpickled sampler, with post_warmup_state set to the saved state.
        rng_key : The rng_key stored on the saved state.
        chain   : The saved chain.

        Raises
        -------------------
        TypeError : If resume_file does not exist (exception type kept for backward compatibility).
        '''

        if not os.path.isfile(resume_file):
            raise TypeError("Checkpoint file <{}> does not exist. Cannot resume from checkpoint.".format(resume_file))

        ## load model and parameters from previous checkpoint
        print("Loading interrupted analysis from last checkpoint...")
        ## SECURITY NOTE: unpickling executes arbitrary code; only load checkpoint files you trust.
        with open(resume_file,'rb') as file:
            engine, state, chain = pickle.load(file)

        ## tell numpyro to start from current state
        engine.post_warmup_state = state
        ## grab rng_key for running
        rng_key = engine.post_warmup_state.rng_key

        return engine, rng_key, chain

    @staticmethod
    def _save_checkpoint(engine, state, chain, checkpoint_file, failure_message):
        '''
        Write [engine, state, chain] to checkpoint_file via a temporary file.

        The payload is written to checkpoint_file + ".temp" first and then moved into place, so an
        interrupted write does not clobber an existing checkpoint. If dill reports the payload as
        unpicklable, nothing is written and failure_message is printed instead.

        NOTE(review): picklability is checked with dill but the write uses pickle.

        Returns
        -------------------
        saved (bool) : True if the checkpoint was written, False otherwise.
        '''
        payload = [engine, state, chain]
        if not dill.pickles(payload):
            print(failure_message)
            return False

        temp_file = checkpoint_file + ".temp"
        with open(temp_file, "wb") as file:
            pickle.dump(payload, file)
        shutil.move(temp_file, checkpoint_file)
        return True

    @staticmethod
    def run_engine_with_checkpointing(engine,lisaModel,rng_key,chain,checkpoint_file,Ntotal,checkpoint_at,resume=False):
        '''
        Runs the numpyro sampler with sampler state checkpointing.

        Arguments
        -------------------
        engine                : Sampler exposing warmup(), run(), get_samples(), last_state and
                                post_warmup_state.
        lisaModel             : Model object handed to the sampler's warmup/run calls.
        rng_key               : PRNG key for the first warmup/run call; replaced by the key stored on
                                the sampler state after each checkpoint.
        chain (dict or None)  : Samples accumulated so far (keys 'theta' and 'theta_transformed'),
                                or None when starting fresh.
        checkpoint_file (str) : Path of the checkpoint file. A sibling "<checkpoint_file>.temp" is
                                used while writing.
        Ntotal (int)          : Target number of samples; only consulted when checkpoint_at='interval',
                                where it is compared to len(chain['theta_transformed'][0]).
        checkpoint_at (str)   : When to checkpoint. Options:
                                    'end' (only saves sampler state at the very end of the run)
                                    'warmup' (saves after warmup phase and at end)
                                    'interval' (saves after warmup, at end, and after every sampler run
                                                until Ntotal samples have been accumulated)
                                Any other value triggers a warning and is treated as 'end'.
                                Note: Generally not worth checkpointing while sampling for large
                                models/datasets, as the recompilation and GPU off/onloading time will
                                exceed the sampling time.
        resume (bool)         : Whether the run is being resumed. If so, skip the warmup phase.

        Returns
        --------------------
        post_samples (array) : Posterior samples (transpose of the accumulated 'theta_transformed').
        '''

        if chain is None and not resume:
            if checkpoint_at=='warmup' or checkpoint_at=='interval':
                print("Beginning sampling, starting warmup phase...")
                ## run warmup phase
                engine.warmup(rng_key,lisaModel)
                state = engine.post_warmup_state
                print("Warmup phase complete. Checkpointing before initializing sampling...")
                ## save
                numpyro_engine._save_checkpoint(
                    engine, state, chain, checkpoint_file,
                    "WARNING: Cannot write checkpoint file, job cannot resume if interrupted.")
                ## continue from the post-warmup state and its rng_key
                engine.post_warmup_state = state
                rng_key = engine.post_warmup_state.rng_key
            ## warn if the checkpointing spec is wonky, but continue on as if it were 'end'
            elif checkpoint_at!='end':
                print("WARNING: Invalid specification of checkpointing behavior (checkpoint_at='{}'). Sampler state will be saved at end of sampling.".format(checkpoint_at))

        print("Initializing sampling...")
        while True:
            engine.run(rng_key,lisaModel)

            ## get state, current chain
            state = engine.last_state
            chain_update = engine.get_samples()

            if chain is not None:
                ## append the new samples to the running chain; only these two keys are carried over
                chain_updated = {}
                chain_updated['theta'] = jnp.append(chain['theta'],chain_update['theta'],axis=0)
                chain_updated['theta_transformed'] = [
                    jnp.append(chain['theta_transformed'][i],chain_update['theta_transformed'][i])
                    for i in range(len(chain_update['theta_transformed']))
                ]
                chain = chain_updated
            else:
                chain = chain_update

            ## only 'interval' mode loops; every other mode runs the sampler exactly once
            if checkpoint_at!='interval':
                break

            ## check to see if we have the desired number of samples yet
            Ncurrent = len(chain['theta_transformed'][0])
            if Ncurrent >= Ntotal:
                break

            print("Checkpointing ({}/{} samples)...".format(Ncurrent,Ntotal))
            ## save
            numpyro_engine._save_checkpoint(
                engine, state, chain, checkpoint_file,
                "Warning: Cannot write checkpoint file, job cannot resume if interrupted.")

            ## tell numpyro to start from current state
            engine.post_warmup_state = state
            ## grab rng_key for running
            rng_key = engine.post_warmup_state.rng_key

        ## save the final state
        print("Sampling complete. Saving final sampler state to {}".format(checkpoint_file))
        numpyro_engine._save_checkpoint(
            engine, state, chain, checkpoint_file,
            "Warning: Failed to save final state to checkpoint file, cannot resume sampling later.")

        ## retrieve samples and reformat
        post_samples = np.array(chain['theta_transformed']).T

        return post_samples
#class numpyro_nested_engine(): # # ''' # Class for interfacing with the Numpyro nested sampling sampler. # ''' # # @classmethod # def define_engine(cls, lisaobj, Nsamples, seed, gpu=False): # # if seed is not None: # rng_key = jax.random.PRNGKey(seed) # else: # raise TypeError("Numpyro sampler requires a defined seed.") # # constructor_kwargs = {'verbose':True,'parameter_estimation':True, 'num_live_points':800, 'max_samples':10*Nsamples} # termination_kwargs = {'dlogZ':1e-4} # ## if there are phase parameters, use the sph wrapper # if len(lisaobj.Model.blm_phase_idx) > 0: # engine = NestedSampler(numpyro_model_sph, constructor_kwargs=constructor_kwargs, termination_kwargs=termination_kwargs) # ## otherwise use the standard one # else: # engine = NestedSampler(numpyro_model, constructor_kwargs=constructor_kwargs, termination_kwargs=termination_kwargs) # engine.num_samples = Nsamples ## engine = MCMC(kernel,num_warmup=Nburn,num_samples=Nsamples,num_chains=Nthreads,chain_method=chain_method,progress_bar=prog) # # # print npar # print("Npar = " + str(lisaobj.Model.Npar)) # # return engine, lisaobj.Model.parameters, rng_key # # @staticmethod # def run_engine(engine,lisaModel,rng_key,checkpoint_file): # # # -------------------- Run HMC sampler --------------------------- # print("Beginning sampling...") # engine.run(rng_key,lisaModel) # print("Sampling complete. Retrieving posterior and plotting results...") # ## retrive samples, resample from weighted samples, and reformat # _, new_key = jax.random.split(rng_key,2) # post_samples = np.array(engine.get_samples(new_key,num_samples=engine.num_samples)['theta_transformed']).T # # print("Sampling complete. 
Saving final sampler state to {}".format(checkpoint_file)) # if dill.pickles(engine): # temp_file = checkpoint_file + ".temp" # with open(temp_file, "wb") as file: # pickle.dump(engine, file) # shutil.move(temp_file, checkpoint_file) # else: # print("Warning: Failed to save final state to checkpoint file, cannot resume sampling later.") # # # return post_samples # # # def load_engine(resume_file): # # ## load model and parameters from previous checkpoint # if os.path.isfile(resume_file): # print("Loading interrupted analysis from last checkpoint...") # with open(resume_file,'rb') as file: # engine = pickle.load(file) # else: # raise TypeError("Checkpoint file <{}> does not exist. Cannot resume from checkpoint.".format(resume_file)) # # return engine #