# Copyright (C) 2022 Gareth Cabourn Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# =============================================================================
#
# Preamble
#
# =============================================================================
#
"""
This module contains functions for reading in command line options and
applying cuts to triggers or templates in the offline search
"""
import logging
import copy
import numpy as np
from pycbc.events import ranking
from pycbc.io import hdf
from pycbc.tmpltbank import bank_conversions as bank_conv
from pycbc.io import get_chisq_from_file_choice
# Only used to check isinstance:
from pycbc.io.hdf import ReadByTemplate
logger = logging.getLogger('pycbc.events.cuts')

# Allowed names of the single-detector reweighted-SNR (ranking) functions
sngl_rank_keys = ranking.sngls_ranking_function_dict.keys()

# Every trigger parameter accepted by --trigger-cuts: the ranking
# functions, one '<choice>_chisq' entry per hdf.chisq_choices value, and a
# few parameters handled directly by apply_trigger_cuts
trigger_param_choices = [
    *sngl_rank_keys,
    *[choice + '_chisq' for choice in hdf.chisq_choices],
    'end_time',
    'psd_var_val',
    'sigmasq',
    'sigma_multiple',
]

# Template parameters which come from the ranking statistic's template fits
template_fit_param_choices = [
    'fit_by_fit_coeff',
    'smoothed_fit_coeff',
    'fit_by_count_above_thresh',
    'smoothed_fit_count_above_thresh',
    'fit_by_count_in_template',
    'smoothed_fit_count_in_template',
]

# Every template parameter accepted by --template-cuts
template_param_choices = (bank_conv.conversion_options
                          + template_fit_param_choices)

# What are the inequalities associated with the cuts? Each function is
# called as f(value, threshold) and is True for the entries which are KEPT:
#   'upper'     : value <  threshold  (threshold is an upper limit)
#   'lower'     : value >  threshold  (threshold is a lower limit)
#   'upper_inc' : value <= threshold
#   'lower_inc' : value >= threshold
ineq_functions = dict(
    upper=np.less,
    lower=np.greater,
    upper_inc=np.less_equal,
    lower_inc=np.greater_equal,
)
ineq_choices = list(ineq_functions)
[docs]
def insert_cuts_option_group(parser):
    """
    Add the --trigger-cuts and --template-cuts options to a parser

    Parameters
    ----------
    parser:
        An object with an argparse-style add_argument method, e.g. an
        argparse.ArgumentParser or an argument group.
    """
    parser.add_argument('--trigger-cuts', nargs='+',
                        help="Cuts to apply to the triggers, supplied as "
                             "PARAMETER:VALUE:LIMIT, where PARAMETER is the "
                             "parameter to be cut, VALUE is the value at "
                             "which it is cut, and LIMIT is one of '"
                             + "', '".join(ineq_choices) +
                             "' to indicate the inequality needed. "
                             "PARAMETER is one of '"
                             + "', '".join(trigger_param_choices) +
                             "'. For example snr:6:lower removes triggers "
                             "with matched filter SNR <= 6")
    parser.add_argument('--template-cuts', nargs='+',
                        help="Cuts to apply to the templates, supplied as "
                             "PARAMETER:VALUE:LIMIT. Format is the same as in "
                             "--trigger-cuts. PARAMETER can be one of '"
                             + "', '".join(template_param_choices) + "'.")
[docs]
def check_update_cuts(cut_dict, new_cut):
    """
    Update a cuts dictionary, keeping only the strictest threshold when a
    cut with the same parameter and inequality function already exists

    Parameters
    ----------
    cut_dict: dictionary
        Dictionary containing the cuts to be checked, keyed on
        (parameter, cut_function) tuples with thresholds as values.
        Updated in place.
    new_cut: dictionary
        Dictionary defining the new cut(s) being considered for addition,
        in the same format as cut_dict. Normally a single entry; every
        entry is checked, and an empty dictionary is a no-op.

    Notes
    -----
    A warning is logged whenever a cut is repeated. The new threshold is
    treated as stricter (and replaces the old one) when
    cut_function(new_threshold, old_threshold) is True, i.e. when the new
    threshold would itself survive the old cut.
    """
    for new_cut_key, value_new in new_cut.items():
        if new_cut_key not in cut_dict:
            # This is a new cut - add it
            cut_dict[new_cut_key] = value_new
            continue

        # The cut has already been called
        cut_function = new_cut_key[1]
        logger.warning("WARNING: Cut parameter %s and function %s have "
                       "already been used. Utilising the strictest cut.",
                       new_cut_key[0], cut_function.__name__)
        value_old = cut_dict[new_cut_key]
        if cut_function(value_new, value_old):
            # The new threshold would survive the cut of the
            # old threshold, therefore the new threshold is stricter
            # - update it
            logger.warning("WARNING: New threshold of %.3f is "
                           "stricter than old threshold %.3f, "
                           "using cut at %.3f.",
                           value_new, value_old, value_new)
            cut_dict[new_cut_key] = value_new
        else:
            # New cut would not make a difference, ignore it
            logger.warning("WARNING: New threshold of %.3f is less "
                           "strict than old threshold %.3f, using "
                           "cut at %.3f.",
                           value_new, value_old, value_old)
[docs]
def ingest_cuts_option_group(args):
    """
    Turn the parsed --trigger-cuts / --template-cuts options into cut
    dictionaries

    Parameters
    ----------
    args:
        Parsed command line options with trigger_cuts and template_cuts
        attributes, each either None/empty or a list of
        PARAMETER:VALUE:LIMIT strings.

    Returns
    -------
    trigger_cut_dict, template_cut_dict: dictionaries
        Keyed on (parameter, cut_function) tuples with thresholds as
        values. Both are empty when no cuts were supplied.
    """
    # Nothing requested at all: no parsing needed
    if not (args.trigger_cuts or args.template_cuts):
        return {}, {}

    def _build(cut_strs, choices):
        # Parse every string; repeated cuts are reconciled by
        # check_update_cuts, which keeps the strictest threshold
        cuts = {}
        for cut_str in cut_strs or []:
            check_update_cuts(cuts, convert_inputstr(cut_str, choices))
        return cuts

    trigger_cut_dict = _build(args.trigger_cuts, trigger_param_choices)
    template_cut_dict = _build(args.template_cuts, template_param_choices)
    return trigger_cut_dict, template_cut_dict
[docs]
def sigma_multiple_cut_thresh(template_ids, statistic,
                              cut_thresh, ifo):
    """
    Get per-trigger sigma thresholds, each a multiple of the median sigma
    of the trigger's template

    Parameters
    ----------
    template_ids:
        template_id values for each of the triggers to be considered,
        used to look up the median sigma of each trigger's template
    statistic:
        A PyCBC ranking statistic instance. Its fits_by_tid[ifo] must
        provide 'median_sigma'. If the statistic has no fits_by_tid
        attribute, a ValueError is raised.
    cut_thresh: int or float
        The multiple of median_sigma to compare triggers to
    ifo:
        The IFO for which we want to read median_sigma

    Returns
    -------
    thresholds: numpy array
        cut_thresh multiplied by the median sigma of each trigger's
        template, one value per entry of template_ids

    Raises
    ------
    ValueError
        If the ranking statistic does not use template fitting
    """
    if not hasattr(statistic, 'fits_by_tid'):
        raise ValueError("Cut parameter 'sigma_multiple' cannot "
                         "be used when the ranking statistic " +
                         statistic.__class__.__name__ + " does not use "
                         "template fitting.")
    tid_med_sigma = statistic.fits_by_tid[ifo]['median_sigma']
    return cut_thresh * tid_med_sigma[template_ids]
[docs]
def _chisq_choice_from_parameter(parameter):
    """Return the chisq choice X from a trigger cut parameter 'X_chisq'."""
    # Strip the suffix rather than splitting on '_': this is the exact
    # inverse of how the '<choice>_chisq' names are built, and stays
    # correct for choices which themselves contain underscores
    return parameter[:-len('_chisq')]


def apply_trigger_cuts(triggers, trigger_cut_dict, statistic=None):
    """
    Fetch/Calculate the parameter for triggers, and then
    apply the cuts defined in trigger_cut_dict

    Parameters
    ----------
    triggers: ReadByTemplate object or dictionary
        The triggers in this particular template. This
        must have the correct datasets required to calculate
        the values we cut on. The number of triggers is taken
        from len(triggers['snr']).
    trigger_cut_dict: dictionary
        Dictionary with tuples of (parameter, cut_function)
        as keys, cut_thresholds as values
        made using ingest_cuts_option_group function
    statistic: optional
        A PyCBC ranking statistic instance. Only used for
        'sigma_multiple' cuts, where it provides the median sigma
        of each template (see sigma_multiple_cut_thresh).

    Returns
    -------
    idx_out: numpy array
        Indices of the triggers which pass all of the cuts, as
        positions within the arrays returned by triggers[...]

    Raises
    ------
    NotImplementedError
        For a 'sigma_multiple' cut on triggers which are not in
        ReadByTemplate format, or for an unrecognised parameter
    """
    idx_out = np.arange(len(triggers['snr']))

    # Loop through the different cuts, and apply them
    for (parameter, cut_function), cut_thresh in trigger_cut_dict.items():
        # What kind of parameter is it?
        if parameter.endswith('_chisq'):
            # parameter is a chisq-type thing
            chisq_choice = _chisq_choice_from_parameter(parameter)
            # Currently calculated for all triggers - this seems inefficient
            value = get_chisq_from_file_choice(triggers, chisq_choice)
            # Apply any previous cuts to the value for comparison
            value = value[idx_out]
        elif parameter == "sigma_multiple":
            if not isinstance(triggers, ReadByTemplate):
                err_msg = "Cuts on 'sigma_multiple' are only implemented for "
                err_msg += "triggers in a ReadByTemplate format. This code "
                err_msg += f"uses a {type(triggers).__name__} format."
                raise NotImplementedError(err_msg)
            # NOTE(review): idx_out indexes the arrays returned by
            # triggers[...] (e.g. triggers['snr']), but here it indexes the
            # full per-IFO datasets in the file. If ReadByTemplate restricts
            # triggers[...] to a single template (presumably so - TODO
            # confirm), these rows will not line up.
            ifo_grp = triggers.file[triggers.ifo]
            value = np.sqrt(ifo_grp['sigmasq'][idx_out])
            template_ids = ifo_grp['template_id'][idx_out]
            # Get a cut threshold value, this will be different
            # depending on the template ID, so we rewrite cut_thresh
            # as a value for each trigger, numpy comparison functions
            # allow this
            cut_thresh = sigma_multiple_cut_thresh(template_ids,
                                                   statistic,
                                                   cut_thresh,
                                                   triggers.ifo)
        else:
            # Where the parameter would be stored if it can be read
            # directly: the per-IFO group of the file, or the dictionary
            if hasattr(triggers, 'file'):
                store = triggers.file[triggers.ifo]
            else:
                store = triggers
            if parameter in store:
                # NOTE(review): for ReadByTemplate this reads the full file
                # dataset; see the indexing note on 'sigma_multiple' above
                # Apply any previous cuts to the value for comparison
                value = store[parameter][idx_out]
            elif parameter in sngl_rank_keys:
                # parameter is a newsnr-type thing
                # Currently calculated for all triggers - this seems
                # inefficient
                value = ranking.get_sngls_ranking_from_trigs(triggers,
                                                             parameter)
                # Apply any previous cuts to the value for comparison
                value = value[idx_out]
            else:
                raise NotImplementedError("Parameter '" + parameter + "' not "
                                          "recognised. Input sanitisation "
                                          "means this shouldn't have "
                                          "happened?!")
        idx_out = idx_out[cut_function(value, cut_thresh)]
    return idx_out
[docs]
def apply_template_fit_cut(statistic, ifos, parameter_cut_function, cut_thresh,
                           template_ids):
    """
    Apply cuts to template fit parameters, these have a few more checks
    needed, so we separate out from apply_template_cuts defined later

    Parameters
    ----------
    statistic:
        A PyCBC ranking statistic instance with a fits_by_tid attribute,
        used to look up the template fit parameter in each ifo. A
        ValueError is raised if it has no fits_by_tid attribute, or if
        the parameter is missing from the fits of any of the ifos.
    ifos: list of strings
        List of IFOS used in this findtrigs instance.
        Templates must pass cuts in all IFOs; an empty list applies
        no cut.
    parameter_cut_function: tuple
        First entry: Which parameter is being used for the cut?
        Second entry: Cut function
    cut_thresh: float or int
        Cut threshold to the parameter according to the cut function
    template_ids: numpy array
        Array of template_ids which have passed previous cuts

    Returns
    -------
    tids_out: numpy array
        Array of template_ids which have passed this cut

    Raises
    ------
    ValueError
        If the statistic does not use template fitting, or the parameter
        is not available in the fits for one of the ifos
    """
    parameter, cut_function = parameter_cut_function

    # We can only apply template fit cuts if template fits have been done
    if not hasattr(statistic, 'fits_by_tid'):
        raise ValueError("Cut parameter " + parameter + " cannot "
                         "be used when the ranking statistic " +
                         statistic.__class__.__name__ + " does not use "
                         "template fitting.")

    # Check every IFO before doing any work, so that a parameter missing
    # from any IFO's fits gives a clear error rather than a KeyError
    # part-way through the cut
    for ifo in ifos:
        if parameter not in statistic.fits_by_tid[ifo]:
            # Shouldn't get here due to input sanitisation
            raise ValueError(f"Cut parameter {parameter} not available "
                             f"in fits file for {ifo}.")

    # Template IDs array to cut down in each IFO; copied so that the
    # caller's array is never returned directly
    tids_out = copy.copy(template_ids)
    for ifo in ifos:
        values = statistic.fits_by_tid[ifo][parameter][tids_out]
        # Only keep templates which pass this cut
        tids_out = tids_out[cut_function(values, cut_thresh)]
    return tids_out
[docs]
def apply_template_cuts(bank, template_cut_dict, template_ids=None,
                        statistic=None, ifos=None):
    """
    Fetch/calculate the parameter for the templates, possibly already
    preselected by template_ids, and then apply the cuts defined
    in template_cut_dict
    As this is used to select templates for use in findtrigs codes,
    we remove anything which does not pass

    Parameters
    ----------
    bank: h5py File object, or a dictionary
        Must contain the usual template bank datasets
    template_cut_dict: dictionary
        Dictionary with tuples of (parameter, cut_function)
        as keys, cut_thresholds as values
        made using ingest_cuts_option_group function

    Optional Parameters
    -------------------
    template_ids: list of indices or numpy array
        Indices of templates to consider within the bank, useful if
        templates have already been down-selected. Converted to a
        numpy array so that it can be masked by the cuts.
    statistic:
        A PyCBC ranking statistic instance. Used for the template fit
        cuts, which raise a ValueError if it has no fits_by_tid, or if the
        fit parameter is missing for one of the ifos.
        If not supplied, template fit cuts are silently not applied.
    ifos: list of strings
        List of IFOS used in this findtrigs instance.
        Templates must pass cuts in all IFOs. This is important
        e.g. for template fit parameter cuts.

    Returns
    -------
    tids_out: numpy array
        Array of template_ids which have passed all cuts

    Raises
    ------
    NotImplementedError
        If only one of statistic and ifos is supplied
    ValueError
        If a cut parameter is not recognised
    """
    # Validate the arguments before doing any work
    if (statistic is None) ^ (ifos is None):
        raise NotImplementedError("Either both or neither of statistic and "
                                  "ifos must be supplied.")

    # Get the initial list of templates:
    if template_ids is None:
        tids_out = np.arange(bank['mass1'].size)
    else:
        # A plain list or range cannot be indexed by the boolean masks
        # which the cut functions return, so use an array
        tids_out = np.asarray(template_ids[:])
        if tids_out.size == 0:
            # np.asarray([]) is float64, which is not a valid index type
            tids_out = tids_out.astype(int)

    if not template_cut_dict:
        # No cuts are defined in the dictionary: just return the
        # list of all tids
        return tids_out

    # Loop through the different cuts, and apply them
    for parameter_cut_function, cut_thresh in template_cut_dict.items():
        # The function and threshold are stored as a tuple so unpack it
        parameter, cut_function = parameter_cut_function

        if parameter in bank_conv.conversion_options:
            # Calculate the parameter values using the bank property helper
            values = bank_conv.get_bank_property(parameter, bank, tids_out)
            # Only keep templates which pass this cut
            tids_out = tids_out[cut_function(values, cut_thresh)]
        elif parameter in template_fit_param_choices:
            # Template fit cuts need the statistic's fits; skipped when
            # no statistic / ifos were supplied
            if statistic and ifos:
                tids_out = apply_template_fit_cut(statistic,
                                                  ifos,
                                                  parameter_cut_function,
                                                  cut_thresh,
                                                  tids_out)
        else:
            raise ValueError("Cut parameter " + parameter + " not recognised."
                             " This shouldn't happen with input sanitisation")
    return tids_out