Source code for dddm.samplers.emcee

"""
Do a likelihood fit. The class MCMCStatModel is used for fitting applying
the MCMC algorithm emcee.

MCMC is:
    slower than the nestle package; and
    harder to use since one has to choose the 'right' initial parameters

Nevertheless, the walkers give great insight in how the likelihood-function is
felt by the steps that the walkers make
"""
from warnings import warn
import datetime
import json
import os
import corner
import matplotlib.pyplot as plt
import numpy as np
from dddm import statistics, utils
import dddm
import typing as ty

export, __all__ = dddm.exporter()
log = dddm.utils.log


def default_emcee_save_dir():
    """The name of folders where to save results from the MCMCStatModel"""
    return 'emcee'


[docs]@export class MCMCStatModel(statistics.StatModel): def __init__( self, wimp_mass: ty.Union[float, int], cross_section: ty.Union[float, int], spectrum_class: ty.Union[dddm.DetectorSpectrum, dddm.GenSpectrum], prior: dict, tmp_folder: str, fit_parameters=('log_mass', 'log_cross_section', 'v_0', 'v_esc', 'density', 'k'), detector_name=None, verbose=False, notes='default', nwalkers=50, nsteps=100, remove_frac=0.2, emcee_thin=15 ): super().__init__(wimp_mass=wimp_mass, cross_section=cross_section, spectrum_class=spectrum_class, prior=prior, tmp_folder=tmp_folder, fit_parameters=fit_parameters, detector_name=detector_name, verbose=verbose, notes=notes) self.nwalkers = nwalkers self.nsteps = nsteps # self.config['fit_parameters'] = ['log_mass', 'log_cross_section'] self.sampler = None self.pos = None self.log_dict = {'sampler': False, 'did_run': False, 'pos': False} self.remove_frac = remove_frac self.emcee_thin = emcee_thin def _set_pos(self, use_pos=None): """Set the starting position of the walkers""" self.log_dict['pos'] = True if use_pos is not None: self.log.info("using specified start position") self.pos = use_pos return nparameters = len(self.config['fit_parameters']) keys = statistics.get_prior_list()[:nparameters] ranges = [self.config['prior'][self.config['fit_parameters'][i]]['range'] for i in range(nparameters)] pos = [] for i, key in enumerate(keys): val = getattr(self, key) self.log.warning(f'{key} is {val}') a, b = ranges[i] start_at = val + 0.005 * val * np.random.randn(self.nwalkers, 1) start_at = np.clip(start_at, a, b) pos.append(start_at) pos = np.hstack(pos) self.pos = pos
[docs] def set_sampler(self, mult=True): """init the MCMC sampler""" # Do the import of emcee inside the class such that the package can be # loaded without emcee try: import emcee except ModuleNotFoundError: raise ModuleNotFoundError('package emcee not found. See README') ndim = len(self.config['fit_parameters']) self.sampler = emcee.EnsembleSampler(self.nwalkers, ndim, self.log_probability, args=([self.config['fit_parameters']]), ) self.log_dict['sampler'] = True
[docs] def run(self): self._fix_parameters() if not self.log_dict['sampler']: self.set_sampler() if not self.log_dict['pos']: self._set_pos() start = datetime.datetime.now() try: self.sampler.run_mcmc(self.pos, self.nsteps, progress=False) except ValueError as e: raise ValueError( f"MCMC did not finish due to a ValueError. Was running with\n" f"pos={self.pos.shape} nsteps = {self.nsteps}, walkers = " f"{self.nwalkers}, ndim = " f"{len(self.config['fit_parameters'])} for fit parameters " f"{self.config['fit_parameters']}") from e end = datetime.datetime.now() self.log_dict['did_run'] = True dt = (end - start).total_seconds() self.log.info(f"fit_done in {dt} s ({dt / 3600} h)") # Release config for writing! self.config = utils._immutable_to_dict(self.config) self.config['fit_time'] = dt
[docs] def show_walkers(self): if not self.log_dict['did_run']: self.run() labels = self.config['fit_parameters'] fig, axes = plt.subplots(len(labels), figsize=(10, 7), sharex=True) samples = self.sampler.get_chain() for i, label_i in enumerate(labels): ax = axes[i] ax.plot(samples[:, :, i], "k", alpha=0.3) ax.set_xlim(0, len(samples)) ax.set_ylabel(label_i) ax.yaxis.set_label_coords(-0.1, 0.5) axes[-1].set_xlabel("step number")
[docs] def show_corner(self): if not self.log_dict['did_run']: self.run() self.log.info( f"Removing a fraction of {self.remove_frac} of the samples, total" f"number of removed samples = {self.nsteps * self.remove_frac}") flat_samples = self._get_chain_flat_chain() truths = [getattr(self, prior_name) for prior_name in statistics.get_prior_list()[:len(self.config['fit_parameters'])]] corner.corner(flat_samples, labels=self.config['fit_parameters'], truths=truths)
def _get_chain_flat_chain(self): return self.sampler.get_chain( discard=int(self.nsteps * self.remove_frac), thin=self.emcee_thin, flat=True )
[docs] def save_results( self, save_to_dir=default_emcee_save_dir(), force_index=False): # save fit parameters to config self.config['fit_parameters'] = self.config['fit_parameters'] if not self.log_dict['did_run']: self.run() # open a folder where to save to results save_dir = dddm.context.open_save_dir( default_emcee_save_dir(), base_dir=save_to_dir, force_index=force_index) # save the config, chain and flattened chain with open(os.path.join(save_dir, 'config.json'), 'w') as fp: json.dump(utils.convert_dic_to_savable(self.config), fp, indent=4) np.save(os.path.join(save_dir, 'config.npy'), utils.convert_dic_to_savable(self.config)) save_at = os.path.join(save_dir, 'full_chain.npy') np.save(save_at, self.sampler.get_chain()) save_at = os.path.join(save_dir, 'flat_chain.npy') flat_chain = self._get_chain_flat_chain() np.save(save_at, flat_chain) self.config['save_dir'] = save_dir self.log.info("save_results::\tdone_saving")
def load_chain_emcee(load_from, item='latest'): files = os.listdir(load_from) if item == 'latest': try: item = files[-1] except ValueError: log.warning(files) item = 0 result = {} load_dir = os.path.join(load_from, str(item)) if not os.path.exists(load_dir): raise FileNotFoundError(f"Cannot find {load_dir} specified by arg: " f"{item}") log.info(f"loading {load_dir}") keys = ['config', 'full_chain', 'flat_chain'] for key in keys: result[key] = np.load( os.path.join( load_dir, key + '.npy'), allow_pickle=True) if key == 'config': result[key] = result[key].item() log.info(f"done loading\naccess result with:\n{keys}") return result def emcee_plots(result, save=False, plot_walkers=True, show=False): if not isinstance(save, bool): assert os.path.exists(save), f"invalid path '{save}'" info = r"$M_\chi}$=%.2f" % 10 ** np.float64(result['config']['log_mass']) for prior_key in result['config']['prior'].keys(): try: mean = result['config']['prior'][prior_key]['mean'] info += f"\n{prior_key} = {mean}" except KeyError: pass nsteps, nwalkers, ndim = np.shape(result['full_chain']) for str_inf in ['notes', 'start', 'fit_time', 'poisson', 'nwalkers', 'nsteps', 'n_energy_bins']: try: info += f"\n{str_inf} = %s" % result['config'][str_inf] if str_inf == 'start': info = info[:-7] if str_inf == 'fit_time': info += 's (%.1f h)' % (float(result['config'][str_inf]) / 3600.) except KeyError: pass info += "\nnwalkers = %s" % nwalkers info += "\nnsteps = %s" % nsteps labels = statistics.get_param_list()[:ndim] truths = [result['config'][prior_name] if prior_name in result['config'] else result['config']['prior'][prior_name]['mean'] for prior_name in statistics.get_prior_list()[:ndim]] fig = corner.corner( result['flat_chain'], labels=labels, range=[0.99999, 0.99999, 0.99999, 0.99999, 0.99999][:ndim], truths=truths, show_titles=True) fig.axes[1].set_title(f"{result['config']['detector']}", loc='left') fig.axes[1].text(0, 1, info, verticalalignment='top') if plot_walkers: _plot_walkers(result, truths, labels, save, show) _plt_cleanup(f"{save}corner.png", save, show) def _plot_walkers(result, truths, labels, save, show): fig, axes = plt.subplots(len(labels), figsize=(10, 5), sharex=True) for i, label_i in enumerate(labels): ax = axes[i] ax.plot(result['full_chain'][:, :, i], "k", alpha=0.3) ax.axhline(truths[i]) ax.set_xlim(0, len(result['full_chain'])) ax.set_ylabel(label_i) ax.yaxis.set_label_coords(-0.1, 0.5) axes[-1].set_xlabel("step number") _plt_cleanup(f"{save}flat_chain.png", save, show) def _plt_cleanup(name, save, show): if save: plt.savefig(name, dpi=200) if show: plt.show() else: plt.clf() plt.close()