# Source code for dddm.samplers.multi_detectors

from __future__ import absolute_import, unicode_literals
import json
import os
import shutil
import numpy as np
import dddm
from .pymultinest import MultiNestSampler, convert_dic_to_savable
from .nestle import NestleSampler
from .ultranest import UltraNestSampler
import typing as ty

# ``dddm.exporter()`` returns a pair that is unpacked into ``export`` and this
# module's ``__all__``. ``export`` is applied to the public classes of this
# module as a decorator.
# NOTE(review): presumably ``export`` registers the decorated name in
# ``__all__`` -- confirm against ``dddm.exporter``.
export, __all__ = dddm.exporter()


class _CombinedInference:
    """Mixin that drives several sub-samplers as one combined inference.

    Every method fans out over ``self.sub_classes``. This mixin does not
    create ``self.sub_classes`` or ``self.log`` itself, and it relies on a
    class further down the MRO for ``get_save_dir`` and for the base
    ``_fix_parameters`` implementation.
    """
    allow_multiple_detectors = True

    def set_models(self):
        """Call ``set_models`` on every sub-sampler."""
        for c in self.sub_classes:
            self.log.debug(f'Setting models for {c}')
            c.set_models()

    def _print_before_run(self):
        """Call ``_print_before_run`` on every sub-sampler."""
        for c in self.sub_classes:
            self.log.debug(f'Printing det config for {c}')
            c._print_before_run()

    def _fix_parameters(self):
        """Fix the parameters of the sub classes, then of this instance.

        The parent implementation is invoked with
        ``_do_evaluate_benchmark=False``.
        """
        for c in self.sub_classes:
            self.log.debug(f'Fixing parameters for {c}')
            c._fix_parameters()
        super()._fix_parameters(_do_evaluate_benchmark=False)

    def _log_probability_nested(self, theta):
        """Return the sum of ``_log_probability_nested(theta)`` over all
        sub-samplers."""
        return np.sum([c._log_probability_nested(theta)
                       for c in self.sub_classes])

    def save_sub_configs(self, force_index=False):
        """Store the config and log file of every sub-sampler.

        For each sub-sampler the following is written to
        ``<save_dir>/sub_exp_configs``, prefixed with ``<detector>_``:
        ``config.json``, ``config.npy`` and a copy of the file that
        ``config['logging']`` points to.

        :param force_index: forwarded to ``self.get_save_dir``.
        :raises ValueError: if a sub-sampler has no ``'logging'`` entry in
            its config.
        """
        save_dir = self.get_save_dir(force_index=force_index)
        self.log.info(
            f'CombinedInference::\tSave configs of sub_experiments to {save_dir}')
        # save the config
        save_dir = os.path.join(save_dir, 'sub_exp_configs')
        # exist_ok avoids the check-then-create race of exists() + mkdir()
        os.makedirs(save_dir, exist_ok=True)
        for c in self.sub_classes:
            if 'logging' not in c.config:
                raise ValueError(f'{c} does not have logging in config ({list(c.config.keys())})')
            log_file = c.config['logging']
            savable_config = convert_dic_to_savable(c.config)
            save_as = os.path.join(save_dir, f'{c.config["detector"]}_')
            with open(f'{save_as}config.json', 'w') as file:
                json.dump(savable_config, file, indent=4)
            np.save(f'{save_as}config.npy', savable_config)
            # basename() instead of split('/') so other path separators work
            shutil.copy(log_file, save_as + os.path.basename(log_file))
        # Report completion once, after all sub-samplers have been handled
        self.log.info('save_sub_configs::\tdone_saving')


@export
class CombinedMultinest(_CombinedInference, MultiNestSampler):
    """``MultiNestSampler`` that combines several detectors in one fit.

    One ``MultiNestSampler`` is created per entry of ``spectrum_class`` and
    stored in ``self.sub_classes``; the ``_CombinedInference`` mixin then
    delegates to those sub-samplers.
    """

    def __init__(
            self,
            wimp_mass: ty.Union[float, int],
            cross_section: ty.Union[float, int],
            spectrum_class: ty.List[ty.Union[dddm.DetectorSpectrum,
                                             dddm.GenSpectrum]],
            prior: dict,
            tmp_folder: str,
            results_dir: str = None,
            fit_parameters=('log_mass', 'log_cross_section', 'v_0', 'v_esc', 'density', 'k'),
            detector_name=None,
            verbose=False,
            notes='default',
            nlive=1024,
            tol=0.1,
    ):
        """Set up the combined sampler and one sub-sampler per detector.

        :param spectrum_class: list of spectrum objects, one per detector.
            Each must expose a ``detector_name`` attribute.
        :param detector_name: name of the combined set of detectors; must
            not be None.

        All other arguments are forwarded unchanged to
        ``MultiNestSampler.__init__``. The sub-samplers receive the same
        arguments, except ``results_dir``, ``nlive`` and ``tol`` (left at
        their defaults) and ``detector_name`` (taken from each spectrum).
        """
        assert detector_name is not None, 'detector_name must be provided'
        # Make list explicit
        spectrum_classes = spectrum_class
        del spectrum_class
        MultiNestSampler.__init__(self,
                                  wimp_mass=wimp_mass,
                                  cross_section=cross_section,
                                  spectrum_class=spectrum_classes,
                                  prior=prior,
                                  tmp_folder=tmp_folder,
                                  fit_parameters=fit_parameters,
                                  detector_name=detector_name,
                                  verbose=verbose,
                                  results_dir=results_dir,
                                  notes=notes,
                                  nlive=nlive,
                                  tol=tol,
                                  )
        if len(spectrum_classes) < 2:
            self.log.warning(
                "Don't use this class for single experiments! Use NestedSamplerStatModel instead")
        self.sub_detectors = spectrum_classes
        self.config['sub_sets'] = [str(sp) for sp in spectrum_classes]
        self.sub_classes = [
            MultiNestSampler(wimp_mass=wimp_mass,
                             cross_section=cross_section,
                             spectrum_class=one_class,
                             prior=prior,
                             tmp_folder=tmp_folder,
                             fit_parameters=fit_parameters,
                             detector_name=one_class.detector_name,
                             verbose=verbose,
                             notes=notes,
                             )
            for one_class in self.sub_detectors
        ]
        self.log.debug(f'Sub detectors are set: {self.sub_classes}')
@export
class CombinedNestle(_CombinedInference, NestleSampler):
    """``NestleSampler`` that combines several detectors in one fit.

    One ``NestleSampler`` is created per entry of ``spectrum_class`` and
    stored in ``self.sub_classes``; the ``_CombinedInference`` mixin then
    delegates to those sub-samplers.
    """

    def __init__(
            self,
            wimp_mass: ty.Union[float, int],
            cross_section: ty.Union[float, int],
            spectrum_class: ty.List[ty.Union[dddm.DetectorSpectrum,
                                             dddm.GenSpectrum]],
            prior: dict,
            tmp_folder: str,
            results_dir: str = None,
            fit_parameters=('log_mass', 'log_cross_section', 'v_0', 'v_esc', 'density', 'k'),
            detector_name=None,
            verbose=False,
            notes='default',
            nlive=1024,
            tol=0.1,
    ):
        """Set up the combined sampler and one sub-sampler per detector.

        :param spectrum_class: list of spectrum objects, one per detector.
            Each must expose a ``detector_name`` attribute.
        :param detector_name: name of the combined set of detectors; must
            not be None.

        All other arguments are forwarded unchanged to
        ``NestleSampler.__init__``. The sub-samplers receive the same
        arguments, except ``results_dir``, ``nlive`` and ``tol`` (left at
        their defaults) and ``detector_name`` (taken from each spectrum).
        """
        assert detector_name is not None, 'detector_name must be provided'
        # Make list explicit
        spectrum_classes = spectrum_class
        del spectrum_class
        NestleSampler.__init__(self,
                               wimp_mass=wimp_mass,
                               cross_section=cross_section,
                               spectrum_class=spectrum_classes,
                               prior=prior,
                               tmp_folder=tmp_folder,
                               fit_parameters=fit_parameters,
                               detector_name=detector_name,
                               verbose=verbose,
                               results_dir=results_dir,
                               notes=notes,
                               nlive=nlive,
                               tol=tol,
                               )
        if len(spectrum_classes) < 2:
            self.log.warning(
                "Don't use this class for single experiments! Use NestedSamplerStatModel instead")
        self.sub_detectors = spectrum_classes
        self.config['sub_sets'] = [str(sp) for sp in spectrum_classes]
        # Sub-samplers use the same backend as the combined sampler, in line
        # with CombinedMultinest and CombinedUltraNest.
        # NOTE(review): assumes NestleSampler accepts this reduced keyword set
        # (defaults for results_dir, nlive and tol) -- confirm.
        self.sub_classes = [
            NestleSampler(wimp_mass=wimp_mass,
                          cross_section=cross_section,
                          spectrum_class=one_class,
                          prior=prior,
                          tmp_folder=tmp_folder,
                          fit_parameters=fit_parameters,
                          detector_name=one_class.detector_name,
                          verbose=verbose,
                          notes=notes,
                          )
            for one_class in self.sub_detectors
        ]
        self.log.debug(f'Sub detectors are set: {self.sub_classes}')
@export
class CombinedUltraNest(_CombinedInference, UltraNestSampler):
    """``UltraNestSampler`` that combines several detectors in one fit.

    One ``UltraNestSampler`` is created per entry of ``spectrum_class`` and
    stored in ``self.sub_classes``; the ``_CombinedInference`` mixin then
    delegates to those sub-samplers.
    """

    def __init__(
            self,
            wimp_mass: ty.Union[float, int],
            cross_section: ty.Union[float, int],
            spectrum_class: ty.List[ty.Union[dddm.DetectorSpectrum,
                                             dddm.GenSpectrum]],
            prior: dict,
            tmp_folder: str,
            results_dir: str = None,
            fit_parameters=('log_mass', 'log_cross_section', 'v_0', 'v_esc', 'density', 'k'),
            detector_name=None,
            verbose=False,
            notes='default',
            nlive=1024,
            tol=0.1,
    ):
        """Set up the combined sampler and one sub-sampler per detector.

        :param spectrum_class: list of spectrum objects, one per detector.
            Each must expose a ``detector_name`` attribute.
        :param detector_name: name of the combined set of detectors; must
            not be None.

        All other arguments are forwarded unchanged to
        ``UltraNestSampler.__init__``. The sub-samplers receive the same
        arguments, except ``results_dir``, ``nlive`` and ``tol`` (left at
        their defaults) and ``detector_name`` (taken from each spectrum).
        """
        assert detector_name is not None, 'detector_name must be provided'
        # Make list explicit
        spectrum_classes = spectrum_class
        del spectrum_class
        UltraNestSampler.__init__(self,
                                  wimp_mass=wimp_mass,
                                  cross_section=cross_section,
                                  spectrum_class=spectrum_classes,
                                  prior=prior,
                                  tmp_folder=tmp_folder,
                                  fit_parameters=fit_parameters,
                                  detector_name=detector_name,
                                  verbose=verbose,
                                  results_dir=results_dir,
                                  notes=notes,
                                  nlive=nlive,
                                  tol=tol,
                                  )
        if len(spectrum_classes) < 2:
            self.log.warning(
                "Don't use this class for single experiments! Use NestedSamplerStatModel instead")
        self.sub_detectors = spectrum_classes
        self.config['sub_sets'] = [str(sp) for sp in spectrum_classes]
        self.sub_classes = [
            UltraNestSampler(wimp_mass=wimp_mass,
                             cross_section=cross_section,
                             spectrum_class=one_class,
                             prior=prior,
                             tmp_folder=tmp_folder,
                             fit_parameters=fit_parameters,
                             detector_name=one_class.detector_name,
                             verbose=verbose,
                             notes=notes,
                             )
            for one_class in self.sub_detectors
        ]
        self.log.debug(f'Sub detectors are set: {self.sub_classes}')