# Source code for pycbc.inference.models.brute_marg

# Copyright (C) 2020 Alex Nitz
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""This module provides model classes that perform brute-force
marginalization at the likelihood level.
"""
import math
import logging
import numpy

from pycbc.pool import BroadcastPool as Pool
from scipy.special import logsumexp

from .gaussian_noise import BaseGaussianNoise
from .tools import draw_sample

# Model evaluated by ``likelihood_wrapper``; set whenever a wrapper is built.
_model = None


class likelihood_wrapper(object):
    """Callable that evaluates the model held in the module global ``_model``.

    The model is stored in the module-level ``_model`` rather than on the
    instance, so the wrapper object itself has no attributes. Every call
    (from any wrapper instance) uses whichever model was most recently
    passed to a ``likelihood_wrapper`` constructor.

    NOTE(review): presumably this is so the wrapper is cheap to send to pool
    workers, which see the model through the global — confirm.
    """

    def __init__(self, model):
        global _model
        _model = model

    def __call__(self, params):
        """Move the global model to ``params`` and evaluate it.

        Parameters
        ----------
        params : dict
            Keyword arguments passed to the model's ``update`` method.

        Returns
        -------
        tuple
            ``(loglr, current_stats)`` read from the model after the update.
        """
        model = _model
        model.update(**params)
        return model.loglr, model.current_stats
class BruteParallelGaussianMarginalize(BaseGaussianNoise):
    """Marginalize a base model over ``coa_phase`` by brute force.

    The base model named by ``base_model`` is evaluated on a uniform grid of
    coalescence phases using a process pool. The returned log likelihood
    ratio is the log of the mean likelihood ratio over that grid.

    Parameters
    ----------
    variable_params : (tuple of) string(s)
        Passed through to the parent class and to the base model.
    cores : int or str, optional
        Number of pool processes used for each likelihood call (default 10).
    base_model : str
        Name of the model to marginalize; used as a key of
        ``pycbc.inference.models.models``.
    marginalize_phase : int or str, optional
        Number of phase grid points (must be at least 1). If not given, no
        marginalization is done and the base model is evaluated directly at
        the current point.
    **kwds :
        Passed to both the parent class and the base model.
    """
    name = "brute_parallel_gaussian_marginalize"

    def __init__(self, variable_params,
                 cores=10,
                 base_model=None,
                 marginalize_phase=None,
                 **kwds):
        super().__init__(variable_params, **kwds)

        from pycbc.inference.models import models
        self.model = models[base_model](variable_params, **kwds)

        # NOTE(review): the wrapper (which sets the module-level model) is
        # created before the pool; presumably the workers rely on this
        # ordering — keep it unless confirmed otherwise.
        self.call = likelihood_wrapper(self.model)

        # size of pool for each likelihood call
        self.pool = Pool(int(cores))

        # Only one for now, but can be easily extended
        self.phase = None
        if marginalize_phase:
            samples = int(marginalize_phase)
            if samples < 1:
                raise ValueError(
                    "marginalize_phase must be a positive number of grid "
                    "points, got {!r}".format(marginalize_phase))
            # The phase is periodic: 0 and 2*pi are the same point, so the
            # endpoint is excluded to avoid giving that phase double weight
            # in the average.
            self.phase = numpy.linspace(0, 2.0 * numpy.pi, samples,
                                        endpoint=False)

    @property
    def _extra_stats(self):
        """Base model's extra stats plus ``maxl_phase`` and ``maxl_loglr``.

        A copy of the base model's list is returned. Appending to that list
        in place would modify the base model's stats and add duplicate
        entries on every access.
        """
        stats = list(self.model._extra_stats)
        for stat in ('maxl_phase', 'maxl_loglr'):
            if stat not in stats:
                stats.append(stat)
        return stats

    def _set_maxl_stats(self, stats, maxl_phase, maxl_loglr):
        """Copy ``stats`` into the current stats and record the maxL values."""
        for stat in stats:
            setattr(self._current_stats, stat, stats[stat])
        self._current_stats.maxl_phase = maxl_phase
        self._current_stats.maxl_loglr = maxl_loglr

    def _loglr_unmarginalized(self):
        """Evaluate the base model directly at the current parameters.

        This is used when no phase grid was configured. The current
        ``coa_phase`` (or nan if it is not a parameter) is reported as
        ``maxl_phase``.
        """
        self.model.update(**self.current_params)
        loglr = self.model.loglr
        stats = self.model.current_stats
        if 'maxl_loglr' in self.model._extra_stats:
            maxl_loglr = stats['maxl_loglr']
        else:
            maxl_loglr = loglr
        self._set_maxl_stats(stats,
                             self.current_params.get('coa_phase', numpy.nan),
                             maxl_loglr)
        return loglr

    def _loglr(self):
        """Return the phase-marginalized log likelihood ratio.

        The base model is evaluated at every grid phase (in the pool). The
        stats of the grid point with the largest ``maxl_loglr`` (or
        ``loglr`` if the base model does not report ``maxl_loglr``) are
        copied into the current stats.
        """
        if self.phase is None:
            return self._loglr_unmarginalized()

        params = []
        for phase in self.phase:
            pref = self.current_params.copy()
            pref['coa_phase'] = phase
            params.append(pref)
        vals = list(self.pool.map(self.call, params))
        loglr = numpy.array([v[0] for v in vals])

        # get the maxl values
        if 'maxl_loglr' not in self.model._extra_stats:
            maxl_loglrs = loglr
        else:
            maxl_loglrs = numpy.array([v[1]['maxl_loglr'] for v in vals])
        maxidx = maxl_loglrs.argmax()
        self._set_maxl_stats(vals[maxidx][1], self.phase[maxidx],
                             maxl_loglrs[maxidx])

        # log of the mean likelihood ratio over the phase grid
        return logsumexp(loglr) - numpy.log(len(self.phase))
class BruteLISASkyModesMarginalize(BaseGaussianNoise):
    """Marginalize a base LISA model over discrete sky-mode symmetries.

    For each likelihood call the base model is evaluated at the points
    produced by ``_apply_sky_point_rotation`` for every mode number: 8 modes,
    or 16 when ``loop_polarization`` is set. The returned log likelihood
    ratio is the log of the mean likelihood ratio over those modes.

    Parameters
    ----------
    variable_params : (tuple of) string(s)
        Passed through to the parent class and to the base model.
    cores : int or str, optional
        Number of pool processes (default 1). With 1, modes are evaluated
        serially with the builtin ``map``.
    loop_polarization : bool, optional
        If true, also evaluate each mode with the polarization shifted by
        pi/2 (16 modes instead of 8).
    base_model : str
        Name of the model to marginalize; used as a key of
        ``pycbc.inference.models.models``.
    **kwds :
        Passed to the parent class. Must contain ``config_object``, whose
        ``[model]`` section options are added before building the base
        model.
    """
    name = "brute_lisa_sky_modes_marginalize"

    def __init__(self, variable_params,
                 cores=1,
                 loop_polarization=False,
                 base_model=None,
                 **kwds):
        super().__init__(variable_params, **kwds)

        from pycbc.inference.models import models
        kwds.update(models[base_model].extra_args_from_config(
            kwds['config_object'],
            "model",
            skip_args=[])
        )

        self.model = models[base_model](variable_params, **kwds)

        self.call = likelihood_wrapper(self.model)

        # size of pool for each likelihood call. Convert with int (as the
        # sibling BruteParallelGaussianMarginalize does) so that a string
        # value also works instead of failing the comparison.
        cores = int(cores)
        if cores > 1:
            self.pool = Pool(cores)
            self.mapfunc = self.pool.map
        else:
            self.pool = None
            self.mapfunc = map

        # Do I explicitly check the polarization + pi/2 points
        # We could also add other arguments here, ie only check longitude
        # or latitude symmetry points.
        # NOTE(review): a non-empty string such as 'False' would be truthy
        # here — confirm how this option is supplied from the config.
        if loop_polarization:
            self.num_sky_modes = 16
        else:
            self.num_sky_modes = 8

        # When True, _loglr returns the per-mode loglr array (see
        # reconstruct) instead of the marginalized value.
        self.reconstruct_sky_points = False

    @property
    def _extra_stats(self):
        """The base model's extra stats."""
        return self.model._extra_stats

    def _loglr(self):
        """Return the sky-mode-marginalized log likelihood ratio.

        Sets ``llr_mode_<n>`` in the current stats for every mode. If
        ``reconstruct_sky_points`` is set, the array of per-mode loglr
        values is returned instead and no stats are set.
        """
        params = []
        for sym_num in range(self.num_sky_modes):
            pref = self.current_params.copy()
            self._apply_sky_point_rotation(pref, sym_num)
            params.append(pref)
        vals = list(self.mapfunc(self.call, params))
        loglr = numpy.array([v[0] for v in vals])

        if self.reconstruct_sky_points:
            return loglr

        # log of the mean likelihood ratio over the modes. logsumexp is
        # numerically stable and returns -inf (not nan) if every mode is
        # -inf.
        marg_llr = logsumexp(loglr) - math.log(self.num_sky_modes)

        # set the stats
        for sym_num, llr in enumerate(loglr):
            setattr(self._current_stats, f'llr_mode_{sym_num}', llr)
        return marg_llr

    def _apply_sky_point_rotation(self, pref, sky_num):
        """Apply the sky point rotation for mode ``sky_num`` to ``pref``.

        ``pref`` is modified in place. For ``sky_num % 8``, the low two bits
        (``% 4``) select a longitude shift of that many multiples of pi/2
        (applied to both longitude and polarization). ``// 4`` selects a
        latitude reflection (beta -> -beta, inclination -> pi - inclination,
        polarization -> pi - polarization). ``sky_num >= 8`` additionally
        shifts the polarization by pi/2. Longitude and polarization are
        wrapped into [0, 2*pi).
        """
        lambdal = pref['eclipticlongitude']
        beta = pref['eclipticlatitude']
        psi = pref['polarization']
        inc = pref['inclination']

        pol_num = sky_num // 8
        sky_num = sky_num % 8
        long_num = sky_num % 4
        lat_num = sky_num // 4

        # Apply latitude symmetry mode
        if lat_num:
            beta = - beta
            inc = numpy.pi - inc
            psi = numpy.pi - psi

        # Apply longitudinal symmetry mode
        lambdal = (lambdal + long_num * 0.5 * numpy.pi) % (2 * numpy.pi)
        psi = (psi + long_num * 0.5 * numpy.pi) % (2 * numpy.pi)

        # Apply additional polarization mode (shouldn't be needed). Wrapped
        # like the longitude step so psi stays in [0, 2*pi).
        if pol_num:
            psi = (psi + math.pi / 2.) % (2 * numpy.pi)

        pref['eclipticlongitude'] = lambdal
        pref['eclipticlatitude'] = beta
        pref['polarization'] = psi
        pref['inclination'] = inc

    @classmethod
    def from_config(cls, cp, **kwargs):
        """Build from a config file, passing it on as ``config_object``."""
        kwargs['config_object'] = cp
        return super().from_config(cp, **kwargs)

    def reconstruct(self, seed=None):
        """Reconstruct a point by unwrapping the sky-mode symmetry.

        Draws one mode index with ``draw_sample`` from the per-mode loglr
        values and applies that mode's rotation to the current parameters.
        It then delegates to the base model's ``reconstruct`` with the
        rotated sky/orientation values, the mode's ``loglr`` and the
        corresponding ``loglikelihood``.

        Parameters
        ----------
        seed : int, optional
            If given (including 0), seeds ``numpy.random`` first. It is also
            passed on to the base model's ``reconstruct``.
        """
        if seed is not None:
            numpy.random.seed(seed)
        rec = {}

        logging.info('Reconstruct LISA sky mode symmetry')
        self.reconstruct_sky_points = True
        try:
            # NOTE(review): if ``loglr`` is cached for the current point,
            # this may not call _loglr; presumably callers call ``update``
            # first — confirm.
            loglr = self.loglr
            xl = draw_sample(loglr)
            logging.info('Found point %d', xl)
            # Undo rotations
            pref = self.current_params.copy()
            self._apply_sky_point_rotation(pref, xl)

            for val in ['polarization', 'eclipticlongitude',
                        'eclipticlatitude', 'inclination']:
                rec[val] = pref[val]
            rec['loglr'] = loglr[xl]
            rec['loglikelihood'] = self.lognl + rec['loglr']
        finally:
            # Always restore normal (marginalized) behaviour of _loglr, even
            # if the evaluation above raises.
            self.reconstruct_sky_points = False
        return self.model.reconstruct(seed=seed, rec=rec)