from __future__ import absolute_import, unicode_literals
import datetime
import json
import os
import shutil
import numpy as np
import dddm
from .pymultinest import MultiNestSampler, multinest_corner, convert_dic_to_savable
# Logger taken from dddm.utils; used by the module-level functions in this file.
log = dddm.utils.log
# ``export`` is the decorator applied to the public sampler class in this
# module; presumably it also registers decorated names in ``__all__`` —
# NOTE(review): confirm against dddm.exporter.
export, __all__ = dddm.exporter()
@export
class NestleSampler(MultiNestSampler):
    """Sampler that runs the fit with the ``nestle`` nested-sampling package.

    It reuses the bookkeeping inherited from ``MultiNestSampler`` (parameter
    fixing, save-directory handling, ``check_did_run``/``check_did_save``).
    The only swap is the optimizer, which becomes ``nestle.sample`` with
    ``method='multi'``. ``nestle`` is imported lazily so that this module can
    be imported without it.
    """

    @staticmethod
    def _import_nestle():
        """Import and return the ``nestle`` module.

        Raises:
            ModuleNotFoundError: if ``nestle`` is not installed.
        """
        # Imported here rather than at module level so the package can be
        # loaded without nestle.
        try:
            import nestle
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                'package nestle not found. See README for installation') from e
        return nestle

    def run(self):
        """Run ``nestle.sample`` over the configured fit parameters.

        On success, it does the following:
            * stores the nestle result in ``self.result``;
            * stores the wall-clock fit time (seconds) in
              ``self.config['fit_time']``;
            * sets ``self.log_dict['did_run']``.

        Raises:
            ModuleNotFoundError: if ``nestle`` is not installed.
            AssertionError: if ``config['fit_parameters']`` is not the leading
                slice of ``self.known_parameters``.
            ValueError: re-raised from ``nestle.sample``; in that case
                ``config['fit_time']`` is set to -1 first.
        """
        self._fix_parameters()
        self._print_before_run()
        nestle = self._import_nestle()
        self.log.debug('We made it to my core function, lets do that optimization')

        method = 'multi'  # use MultiNest algorithm
        ndim = len(self.config['fit_parameters'])
        tol = self.config['tol']  # the stopping criterion (passed as dlogz)
        assert_str = f"Unknown configuration of fit pars: {self.config['fit_parameters']}"
        # The fit parameters must be exactly the first ``ndim`` known
        # parameters, in order.
        # NOTE(review): presumably the nested log-probability / prior-transform
        # callbacks map parameters by position — confirm.
        assert tuple(self.config["fit_parameters"]) == tuple(
            self.known_parameters[:ndim]), assert_str

        self.log.warning(f'run_nestle::\tstart_fit for {ndim} parameters')
        start = datetime.datetime.now()
        try:
            self.result = nestle.sample(
                self._log_probability_nested,
                self._log_prior_transform_nested,
                ndim,
                method=method,
                npoints=self.config['nlive'],
                maxiter=self.config.get('max_iter'),
                dlogz=tol)
        except ValueError:
            # The config is only converted to a plain dict after a successful
            # fit (below). Convert it here as well, so that recording the
            # failure cannot itself raise and mask the ValueError.
            self.config = dddm.utils._immutable_to_dict(self.config)
            self.config['fit_time'] = -1
            self.log.error(
                f'Nestle did not finish due to a ValueError. Was running with '
                f'{self.config["fit_parameters"]}')
            raise
        end = datetime.datetime.now()
        dt = (end - start).total_seconds()
        self.log.info(f'fit_done in {dt} s ({dt / 3600} h)')
        self.config = dddm.utils._immutable_to_dict(self.config)
        self.config['fit_time'] = dt
        self.log_dict['did_run'] = True
        self.log.info('Finished with running optimizer!')

    def get_summary(self):
        """Summarize the finished nestle run.

        Returns:
            dict with the following entries:
              * the number of retained posterior samples;
              * the fit time;
              * ``logZ`` and an estimate of its statistical error;
              * nestle's textual summary;
              * one ``'<param>_fit_res'`` string ("mean +/- std") per fit
                parameter, plus a linear-space entry for ``log_*`` parameters;
              * ``'best_fit'`` (weighted means);
              * ``'cov_matrix'`` (weighted covariance);
              * ``'weighted_samples'`` (posterior samples drawn from the
                weighted nestle samples).

        The posterior samples are selected with the global ``np.random``
        state, so they differ between calls unless that state is seeded.

        Raises:
            ModuleNotFoundError: if ``nestle`` is not installed.
        """
        self.log.info(
            "getting the summary (or at least trying) let's first see if I did run"
        )
        self.check_did_run()
        nestle = self._import_nestle()
        # taken from mattpitkin.github.io/samplers-demo/pages/samplers-samplers-everywhere/#Nestle # noqa
        # estimate of the statistical uncertainty on logZ
        logZerrnestle = np.sqrt(self.result.h / self.config['nlive'])
        # re-scale weights to have a maximum of one
        nweights = self.result.weights / np.max(self.result.weights)
        # keep each sample with probability equal to its rescaled weight
        keepidx = np.where(np.random.rand(len(nweights)) < nweights)[0]
        # get the posterior samples
        samples_nestle = self.result.samples[keepidx, :]
        resdict = {
            'nestle_nposterior': len(samples_nestle),
            'nestle_time': self.config['fit_time'],
            'nestle_logZ': self.result.logz,
            'nestle_logZerr': logZerrnestle,
            'summary': self.result.summary(),
        }
        p, cov = nestle.mean_and_cov(
            self.result.samples, self.result.weights)
        for i, key in enumerate(self.config['fit_parameters']):
            sigma = np.sqrt(cov[i, i])
            resdict[key + '_fit_res'] = (
                '{0:5.2f} +/- {1:5.2f}'.format(p[i], sigma))
            self.log.info(f'\t, {key}, {resdict[key + "_fit_res"]}')
            # Parameters named ``log_<x>`` are fitted in log10 space.
            # They are also reported in linear space under ``<x>``, with the
            # error propagated as d(10**x) = 10**x * ln(10) * dx.
            # Only a leading prefix counts.
            if key.startswith('log_'):
                lin_key = key[len('log_'):]
                resdict[lin_key + '_fit_res'] = '%.3g +/- %.2g' % (
                    10. ** p[i], 10. ** (p[i]) * np.log(10) * sigma)
                self.log.info(
                    f'\t, {lin_key}, {resdict[lin_key + "_fit_res"]}')
        resdict['best_fit'] = p
        resdict['cov_matrix'] = cov
        resdict['weighted_samples'] = samples_nestle
        self.log.info('Alright we got all the info we need')
        return resdict

    def save_results(self, force_index=False):
        """Write the config, fit summary and raw nestle result to disk.

        Files go to ``self.get_save_dir(force_index=force_index)``. Each file
        name is prefixed with ``pid<process id>_``. The following are written:
          * the config and summary, as both ``.json`` and ``.npy``;
          * the weighted samples, as ``.npy``;
          * one ``.npy`` per entry of ``self.result``.
        If ``config['logging']`` is set, that file is copied into the same
        directory.
        """
        self.log.info('Saving results after checking we did run')
        # save fit parameters to config
        self.check_did_run()
        save_dir = self.get_save_dir(force_index=force_index)
        fit_summary = self.get_summary()
        self.log.info(f'storing in {save_dir}')
        pid_id = f'pid{os.getpid()}_'

        def _path(name):
            # Full path of a pid-prefixed output file inside save_dir.
            return os.path.join(save_dir, pid_id + name)

        # save the config, chain and flattened chain
        with open(_path('config.json'), 'w') as file:
            json.dump(convert_dic_to_savable(self.config), file, indent=4)
        with open(_path('res_dict.json'), 'w') as file:
            json.dump(convert_dic_to_savable(fit_summary), file, indent=4)
        np.save(_path('config.npy'), convert_dic_to_savable(self.config))
        np.save(_path('weighted_samples.npy'),
                fit_summary.get('weighted_samples'))
        np.save(_path('res_dict.npy'), convert_dic_to_savable(fit_summary))
        for col in self.result.keys():
            value = self.result[col]
            # Dict-valued entries are passed through convert_dic_to_savable
            # (as the config is); all other entries are saved as-is.
            if col != 'samples' and isinstance(value, dict):
                value = convert_dic_to_savable(value)
            np.save(_path(col + '.npy'), value)
        if 'logging' in self.config:
            log_file = self.config['logging']
            shutil.copy(log_file,
                        os.path.join(save_dir, os.path.basename(log_file)))
        self.log.info('save_results::\tdone_saving')

    def show_corner(self):
        """Load this run's saved results and draw a corner plot of them.

        Requires that ``save_results`` has been called. The results directory
        is passed on to ``nestle_corner`` as its ``save`` argument.
        """
        self.check_did_save()
        save_dir = self.log_dict['saved_in']
        combined_results = load_nestle_samples_from_file(save_dir)
        nestle_corner(combined_results, save_dir)
        self.log.info('Enjoy the plot. Maybe you do want to save it too?')
def _saved_key_of(file_name, keys):
    """Return the longest key in *keys* such that *file_name* ends in ``<key>.npy``.

    Preferring the longest match keeps e.g. ``pid1_weighted_samples.npy``
    from being taken as the ``samples`` file. Returns None if no key matches.
    """
    matches = [key for key in keys if file_name.endswith(key + '.npy')]
    return max(matches, key=len) if matches else None


def load_nestle_samples_from_file(load_dir):
    """Load the ``.npy`` files of a saved nestle run from *load_dir*.

    Each expected key (config, summary, raw nestle result fields and the
    weighted samples) is read from the file whose name ends in
    ``<key>.npy``, e.g. ``pid123_logz.npy``. The file names are as written by
    ``NestleSampler.save_results``. If several files map to the same key,
    the first one in ``os.listdir`` order is used.

    Returns:
        dict mapping each key to the loaded array. ``'config'`` and
        ``'res_dict'`` are unwrapped to the stored Python objects.

    Raises:
        FileNotFoundError: if no file is found for one of the expected keys.
    """
    log.info(f'load_nestle_samples::\tloading {load_dir}')
    keys = ['config', 'res_dict', 'h', 'logl', 'logvol', 'logz', 'logzerr',
            'ncall', 'niter', 'samples', 'weights', 'weighted_samples']
    files_in_dir = os.listdir(load_dir)

    # Assign every file to the single key it belongs to (first one wins).
    file_for_key = {}
    for file in files_in_dir:
        owner = _saved_key_of(file, keys)
        if owner is not None:
            file_for_key.setdefault(owner, file)

    result = {}
    for key in keys:
        if key not in file_for_key:
            raise FileNotFoundError(f'No {key} in {load_dir} only:\n{files_in_dir}')
        # NOTE: allow_pickle=True unpickles arbitrary objects; only load
        # result directories from trusted sources.
        result[key] = np.load(
            os.path.join(load_dir, file_for_key[key]),
            allow_pickle=True)
        if key in ['config', 'res_dict']:
            # These hold Python objects wrapped in a 0-d array; unwrap them.
            result[key] = result[key].item()
    log.info(
        f"load_nestle_samples::\tdone loading\naccess result with:\n{keys}")
    return result
def nestle_corner(result, save=False):
    """Draw a corner plot of nestle results by delegating to ``multinest_corner``.

    Args:
        result: dict of loaded results, e.g. as returned by
            ``load_nestle_samples_from_file``.
        save: forwarded unchanged as the second positional argument of
            ``multinest_corner`` (``NestleSampler.show_corner`` passes the
            results directory here).
    """
    # Nestle output is plotted with the same routine as MultiNest output.
    multinest_corner(result, save)