Source code for pycbc.inference.io.dynesty

# Copyright (C) 2019 Collin Capano, Sumit Kumar
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# self.option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
"""Provides IO for the dynesty sampler.
"""

import argparse
import numpy
from pycbc.io.hdf import (dump_state, load_state)

from .base_nested_sampler import BaseNestedSamplerFile
from .posterior import write_samples_to_file, read_raw_samples_from_file

[docs]class CommonNestedMetadataIO(object): """Provides functions for reading/writing dynesty metadata to file. """
[docs] @staticmethod def extra_args_parser(parser=None, skip_args=None, **kwargs): """Create a parser to parse sampler-specific arguments for loading samples. Parameters ---------- parser : argparse.ArgumentParser, optional Instead of creating a parser, add arguments to the given one. If none provided, will create one. skip_args : list, optional Don't parse the given options. Options should be given as the option string, minus the '--'. For example, ``skip_args=['iteration']`` would cause the ``--iteration`` argument not to be included. \**kwargs : All other keyword arguments are passed to the parser that is created. Returns ------- parser : argparse.ArgumentParser An argument parser with th extra arguments added. actions : list of argparse.Action A list of the actions that were added. """ if parser is None: parser = argparse.ArgumentParser(**kwargs) elif kwargs: raise ValueError("No other keyword arguments should be provded if " "a parser is provided.") if skip_args is None: skip_args = [] actions = [] if 'raw_samples' not in skip_args: act = parser.add_argument( "--raw-samples", action='store_true', default=False, help="Raw samples are the unweighted samples obtained from " "dynesty sampler." "Default value is False which means raw samples are" "weighted by log-weight array obtained from the sampler.") actions.append(act) return parser, actions
[docs]class DynestyFile(CommonNestedMetadataIO, BaseNestedSamplerFile): """Class to handle file IO for the ``dynesty`` sampler.""" name = 'dynesty_file'
[docs] def read_raw_samples(self, fields, raw_samples=False, **kwargs): samples = read_raw_samples_from_file(self, fields, **kwargs) logwt = read_raw_samples_from_file(self, ['logwt'], **kwargs)['logwt'] loglikelihood = read_raw_samples_from_file(self, ['loglikelihood'], **kwargs)['loglikelihood'] logz = self.attrs.get('log_evidence') if raw_samples is False: weights = numpy.exp(logwt - logz) N = len(weights) positions = (numpy.random.random() + numpy.arange(N)) / N idx = numpy.zeros(N, dtype=numpy.int) cumulative_sum = numpy.cumsum(weights) i, j = 0, 0 while i < N: if positions[i] < cumulative_sum[j]: idx[i] = j i += 1 else: j += 1 numpy.random.shuffle(idx) post = {'loglikelihood': loglikelihood[idx]} for i, param in enumerate(self.variable_params): post[param] = samples[param][idx] return post else: return samples
[docs] def write_pickled_data_into_checkpoint_file(self, state): """Dump the sampler state into checkpoint file """ if 'sampler_info/saved_state' not in self: self.create_group('sampler_info/saved_state') dump_state(state, self, path='sampler_info/saved_state')
[docs] def read_pickled_data_from_checkpoint_file(self): """Load the sampler state (pickled) from checkpoint file """ return load_state(self, path='sampler_info/saved_state')
[docs] def write_raw_samples(self, data, parameters=None): """Write the nested samples to the file """ if 'samples' not in self: self.create_group('samples') write_samples_to_file(self, data, parameters=parameters, group='samples')
[docs] def validate(self): """Runs a validation test. This checks that a samples group exist, and that pickeled data can be loaded. Returns ------- bool : Whether or not the file is valid as a checkpoint file. """ try: if 'sampler_info/saved_state' in self: load_state(self, path='sampler_info/saved_state') checkpoint_valid = True except KeyError: checkpoint_valid = False return checkpoint_valid