# Copyright (C) 2018 Collin Capano
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""
This package provides classes and functions for evaluating Bayesian statistics
assuming various noise models.
"""
import logging
from pkg_resources import iter_entry_points as _iter_entry_points
from .base import BaseModel
from .base_data import BaseDataModel
from .analytic import (TestEggbox, TestNormal, TestRosenbrock, TestVolcano,
TestPrior, TestPosterior)
from .gaussian_noise import GaussianNoise
from .marginalized_gaussian_noise import MarginalizedPhaseGaussianNoise
from .marginalized_gaussian_noise import MarginalizedPolarization
from .marginalized_gaussian_noise import MarginalizedHMPolPhase
from .marginalized_gaussian_noise import MarginalizedTime
from .brute_marg import BruteParallelGaussianMarginalize
from .brute_marg import BruteLISASkyModesMarginalize
from .gated_gaussian_noise import (GatedGaussianNoise, GatedGaussianMargPol)
from .single_template import SingleTemplate
from .relbin import Relative, RelativeTime, RelativeTimeDom
from .hierarchical import (HierarchicalModel, MultiSignalModel,
JointPrimaryMarginalizedModel)
# Used to manage a model instance across multiple cores or MPI.
# ``_call_global_model`` and ``_call_global_model_logprior`` call whatever is
# stored here, so it must be set to a callable model (e.g. a ``CallModel``)
# before those functions are used. NOTE(review): nothing visible in this
# module assigns it; presumably the sampler/pool setup does -- confirm.
_global_instance = None
def _call_global_model(*args, **kwds):
    """Evaluate the module-level model instance (needed for parallelization).

    All positional and keyword arguments are forwarded unchanged to
    ``_global_instance`` and its result is returned.
    """
    global_model = _global_instance
    return global_model(*args, **kwds)  # pylint:disable=not-callable
def _call_global_model_logprior(*args, **kwds):
    """Evaluate the module-level model instance with ``callstat='logprior'``.

    Some samplers (for example ``emcee_pt``) take a separate function for the
    log prior; this provides it by forwarding all arguments to
    ``_global_instance`` with the ``callstat`` keyword fixed to
    ``'logprior'``.
    """
    global_model = _global_instance
    # pylint:disable=not-callable
    return global_model(*args, callstat='logprior', **kwds)
[docs]
class CallModel(object):
    """Wrapper class for calling models from a sampler.

    This class can be called like a function, with the parameter values to
    evaluate provided as a list in the same order as the model's
    ``sampling_params``. In that case, the model is updated with the provided
    parameters and then the ``callstat`` retrieved. If ``return_all_stats`` is
    set to ``True``, then all of the stats specified by the model's
    ``default_stats`` will be returned as a tuple, in addition to the stat
    value.

    The model's attributes are promoted to this class's namespace, so that any
    attribute and method of ``model`` may be called directly from this class.
    Attribute lookups made before ``model`` has been set (for example while
    ``copy`` or ``pickle`` rebuilds an instance) raise ``AttributeError``
    instead of recursing.

    This class must be initialized prior to the creation of a ``Pool`` object.

    Parameters
    ----------
    model : Model instance
        The model to call.
    callstat : str
        The statistic to call.
    return_all_stats : bool, optional
        Whether or not to return all of the other statistics along with the
        ``callstat`` value.

    Examples
    --------
    Create a wrapper around an instance of the ``TestNormal`` model, with the
    ``callstat`` set to ``logposterior``:

    >>> from pycbc.inference.models import TestNormal, CallModel
    >>> model = TestNormal(['x', 'y'])
    >>> call_model = CallModel(model, 'logposterior')

    Now call on a set of parameter values:

    >>> call_model([0.1, -0.2])
    (-1.8628770664093453, (0.0, 0.0, -1.8628770664093453))

    Note that a tuple of all of the model's ``default_stats`` were returned in
    addition to the ``logposterior`` value. We can shut this off by toggling
    ``return_all_stats``:

    >>> call_model.return_all_stats = False
    >>> call_model([0.1, -0.2])
    -1.8628770664093453

    Attributes of the model can be called from the call model. For example:

    >>> call_model.variable_params
    ('x', 'y')
    """

    def __init__(self, model, callstat, return_all_stats=True):
        self.model = model
        self.callstat = callstat
        self.return_all_stats = return_all_stats

    def __getattr__(self, attr):
        """Forwards attribute lookups that fail on ``self`` to the model.

        Python only calls this after normal attribute lookup has failed. If
        the missing attribute is ``model`` itself (e.g. on an instance created
        with ``__new__`` by ``copy``/``pickle`` before its state is restored),
        raise ``AttributeError``; otherwise reading ``self.model`` below would
        re-enter this method without end.
        """
        if attr == 'model':
            raise AttributeError(
                "{!r} object has no attribute 'model'".format(
                    type(self).__name__))
        return getattr(self.model, attr)

    def __call__(self, param_values, callstat=None, return_all_stats=None):
        """Updates the model with the given parameter values, then calls the
        call function.

        Parameters
        ----------
        param_values : list of float
            The parameter values to test. Assumed to be in the same order as
            ``model.sampling_params``.
        callstat : str, optional
            Specify which statistic to call. Default is to call whatever self's
            ``callstat`` is set to.
        return_all_stats : bool, optional
            Whether or not to return all stats in addition to the ``callstat``
            value. Default is to use self's ``return_all_stats``.

        Returns
        -------
        stat : float
            The statistic returned by the ``callfunction``.
        all_stats : tuple, optional
            The values of all of the model's ``default_stats`` at the given
            param values. Any stat that has not been calculated is set to
            ``numpy.nan``. This is only returned if ``return_all_stats`` is
            set to ``True``.
        """
        if callstat is None:
            callstat = self.callstat
        if return_all_stats is None:
            return_all_stats = self.return_all_stats
        params = dict(zip(self.model.sampling_params, param_values))
        self.model.update(**params)
        val = getattr(self.model, callstat)
        if return_all_stats:
            return val, self.model.get_current_stats()
        else:
            return val
[docs]
def read_from_config(cp, **kwargs):
    r"""Initializes a model from the given config file.

    The ``[model]`` section must have a ``name`` argument. The name argument
    corresponds to the ``name`` of the model class to initialize.

    Parameters
    ----------
    cp : WorkflowConfigParser
        Config file parser to read.
    \**kwargs :
        All other keyword arguments are passed to the ``from_config`` method
        of the class specified by the name argument.

    Returns
    -------
    cls
        The initialized model.

    Raises
    ------
    KeyError
        If no model with the given name is known (raised by ``get_model``).
    """
    # use the name to look up the model class
    name = cp.get("model", "name")
    return get_model(name).from_config(cp, **kwargs)
# Built-in models, keyed by each class's ``name`` attribute. The order of the
# tuple below is the insertion order of the resulting dict.
_models = {model_cls.name: model_cls for model_cls in (
    TestEggbox,
    TestNormal,
    TestRosenbrock,
    TestVolcano,
    TestPosterior,
    TestPrior,
    GaussianNoise,
    MarginalizedPhaseGaussianNoise,
    MarginalizedPolarization,
    MarginalizedHMPolPhase,
    MarginalizedTime,
    BruteParallelGaussianMarginalize,
    BruteLISASkyModesMarginalize,
    GatedGaussianNoise,
    GatedGaussianMargPol,
    SingleTemplate,
    Relative,
    RelativeTime,
    HierarchicalModel,
    MultiSignalModel,
    RelativeTimeDom,
    JointPrimaryMarginalizedModel,
)}
class _ModelManager(dict):
    """Sub-classes dictionary to manage the collection of available models.

    The first time this is accessed, any plugin models that are available
    (entry points in the ``pycbc.inference.models`` group) will be added to
    the dictionary before returning.
    """

    def __init__(self, *args, **kwargs):
        # must be set before the dict is populated/queried
        self.retrieve_plugins = True
        super().__init__(*args, **kwargs)

    def add_model(self, model):
        """Adds a model to the dictionary.

        If the given model has the same name as a model already in the
        dictionary, the original model will be overridden. A warning will be
        printed in that case.
        """
        # use the base-class check so plugin loading is not triggered here
        if super().__contains__(model.name):
            logging.warning("Custom model %s will override a model of the "
                            "same name. If you don't want this, change the "
                            "model's name attribute and restart.", model.name)
        self[model.name] = model

    def add_plugins(self):
        """Adds any plugin models that are available.

        This will only add the plugins if ``self.retrieve_plugins = True``.
        After this runs, ``self.retrieve_plugins`` is set to ``False``, so that
        subsequent calls to this will not re-add models.
        """
        if self.retrieve_plugins:
            for plugin in _iter_entry_points('pycbc.inference.models'):
                self.add_model(plugin.resolve())
            self.retrieve_plugins = False

    def __len__(self):
        self.add_plugins()
        return super().__len__()

    def __contains__(self, key):
        self.add_plugins()
        return super().__contains__(key)

    def get(self, *args):
        self.add_plugins()
        return super().get(*args)

    def popitem(self):
        self.add_plugins()
        return super().popitem()

    def pop(self, *args):
        # Load plugins before popping a key that is not yet present. Checking
        # membership (rather than waiting for a KeyError) also covers the
        # ``pop(key, default)`` form, which never raises KeyError.
        if args and not super().__contains__(args[0]):
            self.add_plugins()
        return super().pop(*args)

    def keys(self):
        self.add_plugins()
        return super().keys()

    def values(self):
        self.add_plugins()
        return super().values()

    def items(self):
        self.add_plugins()
        return super().items()

    def __iter__(self):
        self.add_plugins()
        return super().__iter__()

    def __repr__(self):
        self.add_plugins()
        return super().__repr__()

    def __getitem__(self, item):
        # fast path: built-in models need no plugin scan
        try:
            return super().__getitem__(item)
        except KeyError:
            self.add_plugins()
            return super().__getitem__(item)

    def __delitem__(self, *args, **kwargs):
        try:
            super().__delitem__(*args, **kwargs)
        except KeyError:
            self.add_plugins()
            super().__delitem__(*args, **kwargs)
# Registry of known models, keyed by each model class's ``name``. It starts
# with the built-in ``_models``; ``_ModelManager`` merges in plugin models
# lazily the first time the registry is queried.
models = _ModelManager(_models)
[docs]
def get_models():
    """Return the module-level model registry.

    Plugin models are loaded into the registry (if that has not happened
    yet) before it is returned.
    """
    registry = models
    registry.add_plugins()
    return registry
[docs]
def get_model(model_name):
    """Look up a model class by name.

    Parameters
    ----------
    model_name : str
        The ``name`` of the model to retrieve.

    Returns
    -------
    model :
        The model registered under ``model_name``.
    """
    registry = get_models()
    return registry[model_name]
[docs]
def available_models():
    """Return the names of all currently available models as a list."""
    registry = get_models()
    names = registry.keys()
    return list(names)
[docs]
def register_model(model):
    """Make a custom model available to PyCBC.

    The model is stored in PyCBC's model registry under its ``name``
    attribute. Registering a model whose ``name`` matches an existing model
    replaces that model and prints a warning.

    Parameters
    ----------
    model : pycbc.inference.models.base.BaseModel
        The model to register. It should be a sub-class of
        :py:class:`BaseModel <pycbc.inference.models.base.BaseModel>` so that
        it has the API expected by ``pycbc_inference``.
    """
    registry = get_models()
    registry.add_model(model)