"""
This module provides classes for using the nessai sampler package for
parameter estimation.
Documentation for nessai: https://nessai.readthedocs.io/en/latest/
"""
import ast
import logging
import os
import nessai.flowsampler
import nessai.model
import nessai.livepoint
import nessai.utils.multiprocessing
import nessai.utils.settings
import numpy
import numpy.lib.recfunctions as rfn
from .base import BaseSampler, setup_output
from .base_mcmc import get_optional_arg_from_config
from ..io import NessaiFile, loadfile
from ...pool import choose_pool
class NessaiSampler(BaseSampler):
    """Class to construct a FlowSampler from the nessai package.

    Parameters
    ----------
    model : models.BaseModel instance
        The PyCBC inference model to sample.
    loglikelihood_function : str
        Name of the log-likelihood attribute on ``model`` to evaluate.
    nlive : int, optional
        Number of live points. Default is 1000.
    nprocesses : int, optional
        Number of processes to use in the pool. Default is 1.
    use_mpi : bool, optional
        If True, use an MPI-based pool. Default is False.
    run_kwds : dict, optional
        Keyword arguments passed to ``FlowSampler.run``.
    extra_kwds : dict, optional
        Keyword arguments passed to the ``FlowSampler`` constructor.
    """

    name = "nessai"
    _io = NessaiFile

    def __init__(
        self,
        model,
        loglikelihood_function,
        nlive=1000,
        nprocesses=1,
        use_mpi=False,
        run_kwds=None,
        extra_kwds=None,
    ):
        super().__init__(model)
        self.nlive = nlive
        self.model_call = NessaiModel(self.model, loglikelihood_function)
        self.extra_kwds = extra_kwds if extra_kwds is not None else {}
        self.run_kwds = run_kwds if run_kwds is not None else {}
        # Make the wrapped model available to nessai's worker processes;
        # this must happen before the pool is created.
        nessai.utils.multiprocessing.initialise_pool_variables(self.model_call)
        self.pool = choose_pool(mpi=use_mpi, processes=nprocesses)
        self.nprocesses = nprocesses
        self._sampler = None
        self._nested_samples = None
        self._posterior_samples = None
        self._logz = None
        self._dlogz = None
        self.checkpoint_file = None
        self.resume_data = None

    @property
    def io(self):
        """The class to use for reading/writing results (NessaiFile)."""
        return self._io

    @property
    def model_stats(self):
        """Not used by the nessai sampler; always returns None."""
        pass

    @property
    def samples(self):
        """The raw nested samples including the corresponding weights"""
        if self._sampler.ns.nested_samples:
            ns = numpy.array(self._sampler.ns.nested_samples)
            samples = nessai.livepoint.live_points_to_dict(
                ns,
                self.model.sampling_params,
            )
            samples["logwt"] = self._sampler.ns.state.log_posterior_weights
            samples["loglikelihood"] = ns["logL"]
            samples["logprior"] = ns["logP"]
            samples["it"] = ns["it"]
        else:
            # Sampling has not produced any nested samples yet
            samples = {}
        return samples

    def run(self, **kwargs):
        """Run the sampler.

        Any keyword arguments that match known nessai options are used to
        update the stored ``extra_kwds``/``run_kwds`` before sampling.
        """
        default_kwds, default_run_kwds = self.get_default_kwds(
            importance_nested_sampler=self.extra_kwds.get(
                "importance_nested_sampler", False
            )
        )
        extra_kwds = self.extra_kwds.copy()
        run_kwds = self.run_kwds.copy()
        # Output in kwargs takes priority of extra kwds.
        output = kwargs.pop("output", extra_kwds.pop("output", None))
        # If neither have been specified, use the path from the checkpoint file
        if output is None:
            output = os.path.join(
                os.path.dirname(os.path.abspath(self.checkpoint_file)),
                "outdir_nessai",
            )
        # ``kwargs`` is always a dict here; only log/update when it is
        # non-empty.
        if kwargs:
            logging.info("Updating keyword arguments with %s", kwargs)
            extra_kwds.update(
                {k: v for k, v in kwargs.items() if k in default_kwds}
            )
            run_kwds.update(
                {k: v for k, v in kwargs.items() if k in default_run_kwds}
            )
        if self._sampler is None:
            logging.info("Initialising nessai FlowSampler")
            # Signal handling and the pool lifetime are managed by PyCBC,
            # so they are disabled/overridden here.
            self._sampler = nessai.flowsampler.FlowSampler(
                self.model_call,
                output=output,
                pool=self.pool,
                n_pool=self.nprocesses,
                close_pool=False,
                signal_handling=False,
                resume_data=self.resume_data,
                checkpoint_callback=self.checkpoint_callback,
                **extra_kwds,
            )
        logging.info("Starting sampling with nessai")
        self._sampler.run(**run_kwds)

    @staticmethod
    def get_default_kwds(importance_nested_sampler=False):
        """Return lists of all allowed keyword arguments for nessai.

        Returns
        -------
        default_kwds : list
            List of keyword arguments that can be passed to FlowSampler
        run_kwds: list
            List of keyword arguments that can be passed to FlowSampler.run
        """
        return nessai.utils.settings.get_all_kwargs(
            importance_nested_sampler=importance_nested_sampler,
            split_kwargs=True,
        )

    @classmethod
    def from_config(
        cls, cp, model, output_file=None, nprocesses=1, use_mpi=False
    ):
        """
        Loads the sampler from the given config file.
        """
        section = "sampler"
        # check name
        assert (
            cp.get(section, "name") == cls.name
        ), "name in section [sampler] must match mine"
        # Use getboolean so strings such as "False" are parsed as booleans;
        # cp.get would return a (truthy) string and the check below would
        # silently never trigger.
        if cp.has_option(section, "importance_nested_sampler"):
            importance_nested_sampler = cp.getboolean(
                section,
                "importance_nested_sampler",
            )
        else:
            importance_nested_sampler = False
        # Requires additional development work, see the model class below
        if importance_nested_sampler:
            raise NotImplementedError(
                "Importance nested sampler is not currently supported"
            )
        default_kwds, default_run_kwds = cls.get_default_kwds(
            importance_nested_sampler
        )
        # Keyword arguments the user cannot configure via the config
        remove_kwds = [
            "pool",
            "n_pool",
            "close_pool",
            "signal_handling",
            "checkpoint_callback",
        ]
        for kwd in remove_kwds:
            default_kwds.pop(kwd, None)
            default_run_kwds.pop(kwd, None)
        kwds = {}
        run_kwds = {}
        # ast.literal_eval is used here since specifying a dictionary with all
        # various types would be difficult. However, one may wish to revisit
        # this in future, e.g. if evaluating code is a concern.
        for d_out, d_defaults in zip(
            [kwds, run_kwds], [default_kwds, default_run_kwds]
        ):
            for k in d_defaults.keys():
                if cp.has_option(section, k):
                    option = cp.get(section, k)
                    try:
                        # literal_eval fails for e.g. a string with an
                        # underscore; it may raise either ValueError or
                        # SyntaxError, so catch both and keep the raw string.
                        option = ast.literal_eval(option)
                    except (ValueError, SyntaxError):
                        pass
                    d_out[k] = option
        # Specified kwds
        ignore_kwds = {"nlive", "name"}
        invalid_kwds = (
            cp[section].keys()
            - set().union(kwds.keys(), run_kwds.keys())
            - ignore_kwds
        )
        if invalid_kwds:
            raise RuntimeError(
                f"Config contains unknown options: {invalid_kwds}"
            )
        logging.info("nessai keyword arguments: %s", kwds)
        logging.info("nessai run keyword arguments: %s", run_kwds)
        loglikelihood_function = get_optional_arg_from_config(
            cp, section, "loglikelihood-function"
        )
        obj = cls(
            model,
            loglikelihood_function=loglikelihood_function,
            nprocesses=nprocesses,
            use_mpi=use_mpi,
            run_kwds=run_kwds,
            extra_kwds=kwds,
        )
        # Do not need to check number of samples for a nested sampler
        setup_output(obj, output_file, check_nsamples=False)
        if not obj.new_checkpoint:
            obj.resume_from_checkpoint()
        return obj

    def set_initial_conditions(
        self,
        initial_distribution=None,
        samples_file=None,
    ):
        """Sets up the starting point for the sampler.

        This is not used for nessai.
        """

    def checkpoint_callback(self, state):
        """Callback for checkpointing.

        This will be called periodically by nessai.
        """
        # Write the pickled sampler state and current results to both the
        # checkpoint and backup files.
        for fn in [self.checkpoint_file, self.backup_file]:
            with self.io(fn, "a") as fp:
                fp.write_pickled_data_into_checkpoint_file(state)
            self.write_results(fn)

    def checkpoint(self):
        """Checkpoint the sampler"""
        self.checkpoint_callback(self._sampler.ns)

    def resume_from_checkpoint(self):
        """Reads the resume data from the checkpoint file."""
        try:
            with loadfile(self.checkpoint_file, "r") as fp:
                self.resume_data = fp.read_pickled_data_from_checkpoint_file()
            logging.info(
                "Found valid checkpoint file: %s", self.checkpoint_file
            )
        except Exception as e:
            # Best-effort resume: fall back to a fresh run if the
            # checkpoint cannot be read.
            logging.info("Failed to load checkpoint file with error: %s", e)

    def finalize(self):
        """Finalize sampling"""
        logz = self._sampler.ns.log_evidence
        dlogz = self._sampler.ns.log_evidence_error
        # Store the evidence on the instance (attributes are initialised in
        # __init__ but were previously never assigned).
        self._logz = logz
        self._dlogz = dlogz
        logging.info("log Z, dlog Z: %s, %s", logz, dlogz)
        self.checkpoint()

    def write_results(self, filename):
        """Write the results to a given file.

        Writes the nested samples, log-evidence and log-evidence error.
        """
        with self.io(filename, "a") as fp:
            fp.write_raw_samples(self.samples)
            fp.write_logevidence(
                self._sampler.ns.log_evidence,
                self._sampler.ns.log_evidence_error,
            )
class NessaiModel(nessai.model.Model):
    """Wrapper for PyCBC Inference model class for use with nessai.

    Parameters
    ----------
    model : inference.BaseModel instance
        A model instance from PyCBC.
    loglikelihood_function : str
        Name of the log-likelihood method to call.
    """

    def __init__(self, model, loglikelihood_function=None):
        self.model = model
        self.names = list(model.sampling_params)
        # Configure the log-likelihood function
        if loglikelihood_function is None:
            loglikelihood_function = "loglikelihood"
        self.loglikelihood_function = loglikelihood_function
        # Configure the priors bounds; only keep bounds for the parameters
        # that are actually being sampled.
        bounds = {}
        for dist in model.prior_distribution.distributions:
            bounds.update(
                {
                    k: [v.min, v.max]
                    for k, v in dist.bounds.items()
                    if k in self.names
                }
            )
        self.bounds = bounds
        # Prior and likelihood are not vectorised
        self.vectorised_likelihood = False
        self.vectorised_prior = False
        # Use the pool for computing the prior
        self.parallelise_prior = True

    def to_dict(self, x):
        """Convert a nessai live point array to a dictionary"""
        return {n: x[n].item() for n in self.names}

    def to_live_points(self, x):
        """Convert to the structured arrays used by nessai"""
        # It is possible this could be made faster
        return nessai.livepoint.numpy_array_to_live_points(
            rfn.structured_to_unstructured(x),
            self.names,
        )

    def new_point(self, N=1):
        """Draw a new point from the model's prior"""
        return self.to_live_points(self.model.prior_rvs(size=N))

    def new_point_log_prob(self, x):
        """Log-probability for the ``new_point`` method"""
        return self.batch_evaluate_log_prior(x)

    def log_prior(self, x):
        """Compute the log-prior"""
        self.model.update(**self.to_dict(x))
        return self.model.logprior

    def log_likelihood(self, x):
        """Compute the log-likelihood"""
        self.model.update(**self.to_dict(x))
        # The configured log-likelihood is a property/attribute of the
        # model, evaluated after updating the model's parameters.
        return getattr(self.model, self.loglikelihood_function)

    def from_unit_hypercube(self, x):
        """Map from the unit-hypercube to the prior."""
        # Needs to be implemented for importance nested sampler
        # This method is already available in pycbc but the inverse is not
        raise NotImplementedError

    def to_unit_hypercube(self, x):
        """Map from the prior to the unit-hypercube."""
        # Needs to be implemented for importance nested sampler
        raise NotImplementedError