Source code for pycbc.inference.io.base_mcmc

# Copyright (C) 2016 Christopher M. Biwer, Collin Capano
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
"""Provides I/O that is specific to MCMC samplers.
"""


import numpy
import argparse


class CommonMCMCMetadataIO(object):
    """Provides functions for reading/writing MCMC metadata to file.

    The functions here are common to both standard MCMC (in which chains
    are independent) and ensemble MCMC (in which chains/walkers share
    information).
    """
    def write_resume_point(self):
        """Keeps a list of the number of iterations that were in a file when
        a run was resumed from a checkpoint."""
        try:
            resume_pts = self.attrs["resume_points"].tolist()
        except KeyError:
            resume_pts = []
        try:
            niterations = self.niterations
        except KeyError:
            niterations = 0
        resume_pts.append(niterations)
        self.attrs["resume_points"] = resume_pts

    def write_niterations(self, niterations):
        """Writes the given number of iterations to the sampler group."""
        self[self.sampler_group].attrs['niterations'] = niterations

    @property
    def niterations(self):
        """Returns the number of iterations the sampler was run for."""
        return self[self.sampler_group].attrs['niterations']

    @property
    def nwalkers(self):
        """Returns the number of walkers used by the sampler.

        Alias of ``nchains``.
        """
        try:
            return self[self.sampler_group].attrs['nwalkers']
        except KeyError:
            return self[self.sampler_group].attrs['nchains']

    @property
    def nchains(self):
        """Returns the number of chains used by the sampler.

        Alias of ``nwalkers``.
        """
        try:
            return self[self.sampler_group].attrs['nchains']
        except KeyError:
            return self[self.sampler_group].attrs['nwalkers']

    def _thin_data(self, group, params, thin_interval):
        """Thins data on disk by the given interval.

        This makes no effort to record the thinning interval that is
        applied.

        Parameters
        ----------
        group : str
            The group where the datasets to thin live.
        params : list
            The list of dataset names to thin.
        thin_interval : int
            The interval to thin the samples on disk by.
        """
        samples = self.read_raw_samples(params, thin_start=0,
                                        thin_interval=thin_interval,
                                        thin_end=None, flatten=False,
                                        group=group)
        # now resize and write the data back to disk
        fpgroup = self[group]
        for param in params:
            data = samples[param]
            # resize the arrays on disk
            fpgroup[param].resize(data.shape)
            # and write
            fpgroup[param][:] = data

    def thin(self, thin_interval):
        """Thins the samples on disk to the given thinning interval.

        The interval must be a multiple of the file's current
        ``thinned_by``.

        Parameters
        ----------
        thin_interval : int
            The interval the samples on disk should be thinned by.
        """
        # get the new interval to thin by
        new_interval = thin_interval / self.thinned_by
        if new_interval % 1:
            raise ValueError("thin interval ({}) must be a multiple of the "
                             "current thinned_by ({})"
                             .format(thin_interval, self.thinned_by))
        new_interval = int(new_interval)
        # now thin the data on disk
        params = list(self[self.samples_group].keys())
        self._thin_data(self.samples_group, params, new_interval)
        # store the interval that samples were thinned by
        self.thinned_by = thin_interval
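
    # A minimal sketch of the expected behavior (hypothetical numbers): if a
    # file's current ``thinned_by`` is 4, then
    #
    #   >>> fp.thin(8)   # fine: thins the on-disk samples by a further factor of 2
    #   >>> fp.thin(6)   # raises ValueError: 6 is not a multiple of 4
    #
    # and after ``fp.thin(8)``, ``fp.thinned_by`` is 8.
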
    @property
    def thinned_by(self):
        """Returns the interval that samples have been thinned by on disk.

        This looks for ``thinned_by`` in the file's ``attrs``. If none is
        found, will just return 1.
        """
        try:
            thinned_by = self.attrs['thinned_by']
        except KeyError:
            thinned_by = 1
        return thinned_by

    @thinned_by.setter
    def thinned_by(self, thinned_by):
        """Sets the thinned_by attribute.

        This is the interval that samples have been thinned by on disk. The
        given value is written to ``self.attrs['thinned_by']``.
        """
        self.attrs['thinned_by'] = int(thinned_by)

    def last_iteration(self, parameter=None, group=None):
        """Returns the iteration of the last sample of the given parameter.

        Parameters
        ----------
        parameter : str, optional
            The name of the parameter to get the last iteration for. If
            None provided, will just use the first parameter in ``group``.
        group : str, optional
            The name of the group to get the last iteration from. Default
            is the ``samples_group``.
        """
        if group is None:
            group = self.samples_group
        if parameter is None:
            try:
                parameter = list(self[group].keys())[0]
            except (IndexError, KeyError):
                # nothing has been written yet, just return 0
                return 0
        try:
            lastiter = self[group][parameter].shape[-1]
        except KeyError:
            # no samples have been written, just return 0
            lastiter = 0
        # account for thinning
        return lastiter * self.thinned_by

    def iterations(self, parameter):
        """Returns the iteration each sample occurred at."""
        return numpy.arange(0, self.last_iteration(parameter),
                            self.thinned_by)

    def write_sampler_metadata(self, sampler):
        """Writes the sampler's metadata."""
        self.attrs['sampler'] = sampler.name
        try:
            self[self.sampler_group].attrs['nchains'] = sampler.nchains
        except AttributeError:
            # the sampler has no nchains; fall back to nwalkers
            self[self.sampler_group].attrs['nwalkers'] = sampler.nwalkers
        # write the model's metadata
        sampler.model.write_metadata(self)

    @property
    def is_burned_in(self):
        """Returns whether or not chains are burned in.

        Raises a ``ValueError`` if no burn-in tests were done.
        """
        try:
            return self[self.sampler_group]['is_burned_in'][()]
        except KeyError:
            raise ValueError("No burn in tests were performed")

    @property
    def burn_in_iteration(self):
        """Returns the burn-in iteration of all the chains.

        Raises a ``ValueError`` if no burn-in tests were done.
        """
        try:
            return self[self.sampler_group]['burn_in_iteration'][()]
        except KeyError:
            raise ValueError("No burn in tests were performed")

    @property
    def burn_in_index(self):
        """Returns the burn-in index.

        This is the burn-in iteration divided by the file's ``thinned_by``.
        Requires the class that this is used with has a
        ``burn_in_iteration`` attribute.
        """
        return self.burn_in_iteration // self.thinned_by

    @property
    def act(self):
        """The autocorrelation time (ACT).

        This is the ACL times the file's ``thinned_by``. Raises a
        ``ValueError`` if the ACT has not been calculated.
        """
        try:
            return self[self.sampler_group]['act'][()]
        except KeyError:
            raise ValueError("ACT has not been calculated")

    @act.setter
    def act(self, act):
        """Writes the autocorrelation time(s).

        ACT(s) are written to the ``sampler_group`` as a dataset with name
        ``act``.

        Parameters
        ----------
        act : array or int
            ACT(s) to write.
        """
        # pylint: disable=no-member
        self.write_data('act', act, path=self.sampler_group)

    @property
    def raw_acts(self):
        """Dictionary of parameter names -> raw autocorrelation time(s).

        Depending on the sampler, the autocorrelation times may be floats,
        or [ntemps x] [nchains x] arrays. Raises a ``ValueError`` if no raw
        acts have been set.
        """
        try:
            group = self[self.sampler_group]['raw_acts']
        except KeyError:
            raise ValueError("ACTs have not been calculated")
        acts = {}
        for param in group:
            acts[param] = group[param][()]
        return acts

    @raw_acts.setter
    def raw_acts(self, acts):
        """Writes the raw autocorrelation times.

        The ACT of each parameter is saved to
        ``[sampler_group]/raw_acts/{param}``. Works for all types of MCMC
        samplers (independent chains, ensemble, parallel tempering).

        Parameters
        ----------
        acts : dict
            A dictionary of ACTs keyed by the parameter.
        """
        path = self.sampler_group + '/raw_acts'
        for param in acts:
            self.write_data(param, acts[param], path=path)

    @property
    def acl(self):
        """The autocorrelation length (ACL) of the samples.

        This is the autocorrelation time (ACT) divided by the file's
        ``thinned_by`` attribute. Raises a ``ValueError`` if the ACT has
        not been calculated.
        """
        return self.act / self.thinned_by

    @acl.setter
    def acl(self, acl):
        """Sets the autocorrelation length (ACL) of the samples.

        This will convert the given value(s) to autocorrelation time(s) and
        save to the ``act`` attribute; see that attribute for details.
        """
        self.act = acl * self.thinned_by

    @property
    def raw_acls(self):
        """Dictionary of parameter names -> raw autocorrelation length(s).

        Depending on the sampler, the autocorrelation lengths may be
        floats, or [ntemps x] [nchains x] arrays. The ACLs are the
        autocorrelation times (ACT) divided by the file's ``thinned_by``
        attribute. Raises a ``ValueError`` if no raw acts have been set.
        """
        return {p: self.raw_acts[p] / self.thinned_by
                for p in self.raw_acts}

    @raw_acls.setter
    def raw_acls(self, acls):
        """Sets the raw autocorrelation lengths.

        The given ACLs are converted to autocorrelation times (ACTs) and
        saved to the ``raw_acts`` attribute; see that attribute for
        details.

        Parameters
        ----------
        acls : dict
            A dictionary of ACLs keyed by the parameter.
        """
        self.raw_acts = {p: acls[p] * self.thinned_by for p in acls}

    def _update_sampler_history(self):
        """Writes the number of iterations, effective number of samples,
        autocorrelation times, and burn-in iteration to the history.
        """
        path = '/'.join([self.sampler_group, 'checkpoint_history'])
        # write the current number of iterations
        self.write_data('niterations', self.niterations, path=path,
                        append=True)
        self.write_data('effective_nsamples', self.effective_nsamples,
                        path=path, append=True)
        # write the act: we'll make sure that this is 2D, so that the acts
        # can be appended along the last dimension
        try:
            act = self.act
        except ValueError:
            # no acts were calculated
            act = None
        if act is not None:
            act = act.reshape(tuple(list(act.shape)+[1]))
            self.write_data('act', act, path=path, append=True)
        # write the burn-in iteration in the same way
        try:
            burn_in = self.burn_in_iteration
        except ValueError:
            # no burn-in tests were done
            burn_in = None
        if burn_in is not None:
            burn_in = burn_in.reshape(tuple(list(burn_in.shape)+[1]))
            self.write_data('burn_in_iteration', burn_in, path=path,
                            append=True)

    @staticmethod
    def extra_args_parser(parser=None, skip_args=None, **kwargs):
        """Create a parser to parse sampler-specific arguments for loading
        samples.

        Parameters
        ----------
        parser : argparse.ArgumentParser, optional
            Instead of creating a parser, add arguments to the given one.
            If none provided, will create one.
        skip_args : list, optional
            Don't parse the given options. Options should be given as the
            option string, minus the '--'. For example,
            ``skip_args=['iteration']`` would cause the ``--iteration``
            argument not to be included.
        \**kwargs :
            All other keyword arguments are passed to the parser that is
            created.

        Returns
        -------
        parser : argparse.ArgumentParser
            An argument parser with the extra arguments added.
        actions : list of argparse.Action
            A list of the actions that were added.
        """
        if parser is None:
            parser = argparse.ArgumentParser(**kwargs)
        elif kwargs:
            raise ValueError("No other keyword arguments should be "
                             "provided if a parser is provided.")
        if skip_args is None:
            skip_args = []
        actions = []
        if 'thin-start' not in skip_args:
            act = parser.add_argument(
                "--thin-start", type=int, default=None,
                help="Sample number to start collecting samples. If "
                     "none provided, will use the input file's "
                     "`thin_start` attribute.")
            actions.append(act)
        if 'thin-interval' not in skip_args:
            act = parser.add_argument(
                "--thin-interval", type=int, default=None,
                help="Interval to use for thinning samples. If none "
                     "provided, will use the input file's `thin_interval` "
                     "attribute.")
            actions.append(act)
        if 'thin-end' not in skip_args:
            act = parser.add_argument(
                "--thin-end", type=int, default=None,
                help="Sample number to stop collecting samples. If "
                     "none provided, will use the input file's `thin_end` "
                     "attribute.")
            actions.append(act)
        if 'iteration' not in skip_args:
            act = parser.add_argument(
                "--iteration", type=int, default=None,
                help="Only retrieve the given iteration. To load "
                     "the last n-th sample use -n, e.g., -1 will "
                     "load the last iteration. This overrides "
                     "the thin-start/interval/end options.")
            actions.append(act)
        if 'walkers' not in skip_args and 'chains' not in skip_args:
            act = parser.add_argument(
                "--walkers", "--chains", type=int, nargs="+", default=None,
                help="Only retrieve samples from the listed "
                     "walkers/chains. Default is to retrieve from all "
                     "walkers/chains.")
            actions.append(act)
        return parser, actions

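# A short usage sketch for ``extra_args_parser`` (hypothetical option
# values):
#
#   >>> parser, actions = CommonMCMCMetadataIO.extra_args_parser()
#   >>> opts = parser.parse_args(['--thin-interval', '10',
#   ...                           '--walkers', '0', '3'])
#   >>> opts.thin_interval, opts.walkers
#   (10, [0, 3])
#
# Note that ``--walkers`` and ``--chains`` are aliases for the same option,
# stored under the ``walkers`` destination.
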
class MCMCMetadataIO(object):
    """Provides functions for reading/writing metadata to file for MCMCs in
    which all chains are independent of each other.

    Overrides the ``BaseInference`` file's ``thin_start`` and
    ``thin_interval`` attributes. Instead of integers, these return arrays.
    """
    @property
    def thin_start(self):
        """Returns the default thin start to use for reading samples.

        If burn-in tests were done, this will return the burn-in index of
        every chain that has burned in. The start index for chains that
        have not burned in will be greater than the number of samples, so
        that those chains return no samples. If no burn-in tests were done,
        returns 0 for all chains.
        """
        # pylint: disable=no-member
        try:
            thin_start = self.burn_in_index
            # replace any that have not been burned in with the number
            # of iterations; this will cause those chains to not return
            # any samples
            thin_start[~self.is_burned_in] = \
                int(numpy.ceil(self.niterations/self.thinned_by))
            return thin_start
        except ValueError:
            # no burn in, just return array of zeros
            return numpy.zeros(self.nchains, dtype=int)

    @property
    def thin_interval(self):
        """Returns the default thin interval to use for reading samples.

        If a finite ACL exists in the file, will return that. Otherwise,
        returns 1.
        """
        try:
            acl = self.acl
        except ValueError:
            return numpy.ones(self.nchains, dtype=int)
        # replace any infs with the number of samples
        acl[numpy.isinf(acl)] = self.niterations / self.thinned_by
        return numpy.ceil(acl).astype(int)

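# A minimal sketch of these per-chain defaults (hypothetical numbers): for a
# file with niterations = 4000 and thinned_by = 4 (so 1000 samples on disk),
# with 4 chains of which chain 2 never burned in, one might see
#
#   >>> fp.thin_start
#   array([ 250,  300, 1000,  275])   # chain 2 starts past its last sample
#   >>> fp.thin_interval
#   array([  12,   15, 1000,   11])   # chain 2's infinite ACL was replaced
#
# so reading with the defaults returns only independent, post-burn-in
# samples, and nothing at all from the chain that failed to burn in.
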
class EnsembleMCMCMetadataIO(object):
    """Provides functions for reading/writing metadata to file for ensemble
    MCMCs.
    """
    @property
    def thin_start(self):
        """Returns the default thin start to use for reading samples.

        If burn-in tests were done, returns the burn-in index. Otherwise,
        returns 0.
        """
        try:
            return self.burn_in_index
        except ValueError:
            # no burn in, just return 0
            return 0

    @property
    def thin_interval(self):
        """Returns the default thin interval to use for reading samples.

        If a finite ACL exists in the file, will return that. Otherwise,
        returns 1.
        """
        try:
            acl = self.acl
        except ValueError:
            acl = 1
        if numpy.isfinite(acl):
            acl = int(numpy.ceil(acl))
        else:
            acl = 1
        return acl

def write_samples(fp, samples, parameters=None, last_iteration=None,
                  samples_group=None, thin_by=None):
    """Writes samples to the given file.

    This works for both standard MCMC and ensemble MCMC samplers without
    parallel tempering.

    Results are written to ``samples_group/{vararg}``, where ``{vararg}``
    is the name of a model param. The samples are written as an
    ``nwalkers x niterations`` array. If samples already exist, the new
    samples are appended to the current.

    If the current samples on disk have been thinned (determined by the
    ``thinned_by`` attribute in the samples group), then the samples will
    be thinned by the same amount before being written. The thinning is
    started at the sample in ``samples`` that occurred at the iteration
    equal to the last iteration on disk plus the ``thinned_by`` interval.
    If this iteration is larger than the iteration of the last given
    sample, then none of the samples will be written.

    Parameters
    ----------
    fp : BaseInferenceFile
        Open file handler to write files to. Must be an instance of
        BaseInferenceFile with CommonMCMCMetadataIO methods added.
    samples : dict
        The samples to write. Each array in the dictionary should have
        shape nwalkers x niterations.
    parameters : list, optional
        Only write the specified parameters to the file. If None, will
        write all of the keys in the ``samples`` dict.
    last_iteration : int, optional
        The iteration of the last sample. If the file's ``thinned_by``
        attribute is > 1, this is needed to determine where to start
        thinning the samples such that the interval between the last
        sample currently on disk and the first new sample is the same as
        all of the other samples.
    samples_group : str, optional
        Which group to write the samples to. Default (None) will result in
        writing to "samples".
    thin_by : int, optional
        Override the ``thinned_by`` attribute in the file with the given
        value. **Only set this if you are using this function to write
        something other than inference samples!**
    """
    nwalkers, nsamples = list(samples.values())[0].shape
    assert all(p.shape == (nwalkers, nsamples)
               for p in samples.values()), (
        "all samples must have the same shape")
    if samples_group is None:
        samples_group = fp.samples_group
    if parameters is None:
        parameters = samples.keys()
    # thin the samples
    samples = thin_samples_for_writing(fp, samples, parameters,
                                       last_iteration, samples_group,
                                       thin_by=thin_by)
    # loop over number of dimensions
    group = samples_group + '/{name}'
    for param in parameters:
        dataset_name = group.format(name=param)
        data = samples[param]
        # check that there's something to write after thinning
        if data.shape[1] == 0:
            # nothing to write, move along
            continue
        try:
            fp_nsamples = fp[dataset_name].shape[-1]
            istart = fp_nsamples
            istop = istart + data.shape[1]
            if istop > fp_nsamples:
                # resize the dataset
                fp[dataset_name].resize(istop, axis=1)
        except KeyError:
            # dataset doesn't exist yet
            istart = 0
            istop = istart + data.shape[1]
            fp.create_dataset(dataset_name, (nwalkers, istop),
                              maxshape=(nwalkers, None),
                              dtype=data.dtype, fletcher32=True)
        fp[dataset_name][:, istart:istop] = data

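# A minimal usage sketch (hypothetical parameter names; assumes ``fp`` is an
# open, writable BaseInferenceFile with CommonMCMCMetadataIO mixed in, and a
# 64-walker sampler that has just finished iteration 1000):
#
#   >>> samples = {'mass1': numpy.random.normal(size=(64, 100)),
#   ...            'mass2': numpy.random.normal(size=(64, 100))}
#   >>> write_samples(fp, samples, last_iteration=1000)
#
# Each call appends along the iterations axis, so checkpointing a run
# amounts to repeated calls with the latest chunk of samples and an updated
# last_iteration.
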
def ensemble_read_raw_samples(fp, fields, thin_start=None,
                              thin_interval=None, thin_end=None,
                              iteration=None, walkers=None, flatten=True,
                              group=None):
    """Base function for reading samples from ensemble MCMC files without
    parallel tempering.

    Parameters
    ----------
    fp : BaseInferenceFile
        Open file handler to read samples from. Must be an instance of
        BaseInferenceFile with EnsembleMCMCMetadataIO methods added.
    fields : list
        The list of field names to retrieve.
    thin_start : int, optional
        Start reading from the given iteration. Default is to start from
        the first iteration.
    thin_interval : int, optional
        Only read every ``thin_interval``-th sample. Default is 1.
    thin_end : int, optional
        Stop reading at the given iteration. Default is to end at the last
        iteration.
    iteration : int, optional
        Only read the given iteration. If this is provided, it overrides
        the ``thin_(start|interval|end)`` options.
    walkers : (list of) int, optional
        Only read from the given walkers. Default (``None``) is to read
        all.
    flatten : bool, optional
        Flatten the samples to 1D arrays before returning. Otherwise, the
        returned arrays will have shape (requested walkers x requested
        iteration(s)). Default is True.
    group : str, optional
        The name of the group to read sample datasets from. Default is
        the file's ``samples_group``.

    Returns
    -------
    dict
        A dictionary of field name -> numpy array pairs.
    """
    if isinstance(fields, str):
        fields = [fields]
    # walkers to load
    widx, nwalkers = _ensemble_get_walker_index(fp, walkers)
    # get the slice to use
    get_index = _ensemble_get_index(fp, thin_start, thin_interval, thin_end,
                                    iteration)
    # load
    if group is None:
        group = fp.samples_group
    group = group + '/{name}'
    arrays = {}
    for name in fields:
        arr = fp[group.format(name=name)][widx, get_index]
        niterations = arr.shape[-1] if iteration is None else 1
        if flatten:
            arr = arr.flatten()
        else:
            # ensure that the returned array is 2D
            arr = arr.reshape((nwalkers, niterations))
        arrays[name] = arr
    return arrays

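# A minimal usage sketch (hypothetical field names; assumes ``fp`` is an
# open ensemble-MCMC BaseInferenceFile whose defaults yield, say, 500
# thinned iterations):
#
#   >>> arrs = ensemble_read_raw_samples(fp, ['mass1', 'mass2'],
#   ...                                  walkers=[0, 1], flatten=False)
#   >>> arrs['mass1'].shape
#   (2, 500)   # (requested walkers, thinned iterations)
#
# With flatten=True (the default), the same call would instead return 1D
# arrays of length 1000.
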
def _ensemble_get_walker_index(fp, walkers=None):
    """Convenience function to determine which walkers to load.

    Parameters
    ----------
    fp : BaseInferenceFile
        Open file handler to read samples from. Must be an instance of
        BaseInferenceFile with EnsembleMCMCMetadataIO methods added.
    walkers : (list of) int, optional
        Only read from the given walkers. Default (``None``) is to read
        all.

    Returns
    -------
    widx : array or slice
        The walker indices to load.
    nwalkers : int
        The number of walkers that will be loaded.
    """
    if walkers is not None:
        widx = numpy.zeros(fp.nwalkers, dtype=bool)
        widx[walkers] = True
        nwalkers = widx.sum()
    else:
        widx = slice(None, None)
        nwalkers = fp.nwalkers
    return widx, nwalkers


def _ensemble_get_index(fp, thin_start=None, thin_interval=None,
                        thin_end=None, iteration=None):
    """Determines the sample indices to retrieve for an ensemble MCMC.

    Parameters
    ----------
    fp : BaseInferenceFile
        Open file handler to read samples from. Must be an instance of
        BaseInferenceFile with EnsembleMCMCMetadataIO methods added.
    thin_start : int, optional
        Start reading from the given iteration. Default is to start from
        the first iteration.
    thin_interval : int, optional
        Only read every ``thin_interval``-th sample. Default is 1.
    thin_end : int, optional
        Stop reading at the given iteration. Default is to end at the last
        iteration.
    iteration : int, optional
        Only read the given iteration. If this is provided, it overrides
        the ``thin_(start|interval|end)`` options.

    Returns
    -------
    slice or int
        The indices to retrieve.
    """
    if iteration is not None:
        get_index = int(iteration)
    else:
        if thin_start is None:
            thin_start = fp.thin_start
        if thin_interval is None:
            thin_interval = fp.thin_interval
        if thin_end is None:
            thin_end = fp.thin_end
        get_index = fp.get_slice(thin_start=thin_start,
                                 thin_interval=thin_interval,
                                 thin_end=thin_end)
    return get_index


def _get_index(fp, chains, thin_start=None, thin_interval=None,
               thin_end=None, iteration=None):
    """Determines the sample indices to retrieve for an MCMC with
    independent chains.

    Parameters
    ----------
    fp : BaseInferenceFile
        Open file handler to read samples from. Must be an instance of
        BaseInferenceFile with MCMCMetadataIO methods added.
    chains : array of int
        The chains to load.
    thin_start : array or int, optional
        Start reading from the given sample. May either provide an array
        indicating the start index for each chain, or an integer. If the
        former, the array must have the same length as the number of
        chains that will be retrieved. If the latter, the given value will
        be used for all chains. Default (None) is to use the file's
        ``thin_start`` attribute.
    thin_interval : array or int, optional
        Only read every ``thin_interval``-th sample. May either provide an
        array indicating the interval to use for each chain, or an
        integer. If the former, the array must have the same length as the
        number of chains that will be retrieved. If the latter, the given
        value will be used for all chains. Default (None) is to use the
        file's ``thin_interval`` attribute.
    thin_end : array or int, optional
        Stop reading at the given sample index. May either provide an
        array indicating the end index to use for each chain, or an
        integer. If the former, the array must have the same length as the
        number of chains that will be retrieved. If the latter, the given
        value will be used for all chains. Default (None) is to use the
        file's ``thin_end`` attribute.
    iteration : int, optional
        Only read the given iteration from all chains. If provided, it
        overrides the ``thin_(start|interval|end)`` options.

    Returns
    -------
    get_index : list of slice or int
        The indices to retrieve.
    """
    nchains = len(chains)
    # convenience function to get the right thin start/interval/end
    if iteration is not None:
        get_index = [int(iteration)]*nchains
    else:
        # get the slice arguments
        thin_start = _format_slice_arg(thin_start, fp.thin_start, chains)
        thin_interval = _format_slice_arg(thin_interval, fp.thin_interval,
                                          chains)
        thin_end = _format_slice_arg(thin_end, fp.thin_end, chains)
        # the slices to use for each chain
        get_index = [fp.get_slice(thin_start=thin_start[ci],
                                  thin_interval=thin_interval[ci],
                                  thin_end=thin_end[ci])
                     for ci in range(nchains)]
    return get_index


def _format_slice_arg(value, default, chains):
    """Formats a start/interval/end argument for picking out chains.

    Parameters
    ----------
    value : None, int, array or list of int
        The thin-start/interval/end value to format. ``None`` indicates
        the user did not specify anything, in which case ``default`` will
        be used. If an integer, then it will be repeated to match the
        length of ``chains``. If an array or list, it must have the same
        length as ``chains``.
    default : array
        What to use instead if ``value`` is ``None``.
    chains : array of int
        The index values of chains that will be loaded.

    Returns
    -------
    array
        Array giving the value to use for each chain in ``chains``. The
        array will have the same length as ``chains``.
    """
    if value is None and default is None:
        # no value provided, and default is None, just return Nones with
        # the same length as chains
        value = [None]*len(chains)
    elif value is None:
        # use the default, with the desired values extracted
        value = default[chains]
    elif isinstance(value, (int, numpy.int_)):
        # a single integer was provided, repeat into an array
        value = numpy.repeat(value, len(chains))
    elif len(value) != len(chains):
        # a list of values was provided, but the length does not match the
        # chains, raise an error
        raise ValueError("Number of requested thin-start/interval/end "
                         "values ({}) does not match number of requested "
                         "chains ({})".format(len(value), len(chains)))
    return value

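# A worked sketch of ``_format_slice_arg``'s branches (hypothetical values):
#
#   >>> chains = numpy.array([0, 2])
#   >>> _format_slice_arg(None, numpy.array([10, 20, 30]), chains)
#   array([10, 30])    # defaults picked out for the requested chains
#   >>> _format_slice_arg(5, None, chains)
#   array([5, 5])      # a scalar is repeated, one value per chain
#   >>> _format_slice_arg([5], None, chains)   # 1 value for 2 chains
#   ValueError
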
def thin_samples_for_writing(fp, samples, parameters, last_iteration,
                             group, thin_by=None):
    """Thins samples for writing to disk.

    The thinning interval to use is determined by the given file handler's
    ``thinned_by`` attribute. If that attribute is 1, just returns the
    samples.

    Parameters
    ----------
    fp : CommonMCMCMetadataIO instance
        The file the samples will be written to. Needed to determine the
        thin interval used on disk.
    samples : dict
        Dictionary mapping parameter names to arrays of (unthinned)
        samples. The arrays are thinned along their last dimension.
    parameters : list of str
        The parameters to thin in ``samples`` before writing. All listed
        parameters must be in ``samples``.
    last_iteration : int
        The iteration that the last sample in ``samples`` occurred at.
        This is needed to figure out where to start the thinning in
        ``samples``, such that the interval between the last sample on
        disk and the first new sample is the same as all of the other
        samples.
    group : str
        The name of the group that the samples will be written to. This is
        needed to determine what the last iteration saved on disk was.
    thin_by : int, optional
        Override the ``thinned_by`` attribute in the file with the given
        value. **Only do this if you are thinning something other than
        inference samples!**

    Returns
    -------
    dict :
        Dictionary of the thinned samples to write.
    """
    if thin_by is None:
        thin_by = fp.thinned_by
    if thin_by > 1:
        if last_iteration is None:
            raise ValueError("File's thinned_by attribute is > 1 ({}), "
                             "but last_iteration not provided."
                             .format(thin_by))
        thinned_samples = {}
        for param in parameters:
            data = samples[param]
            nsamples = data.shape[-1]
            # To figure out where to start:
            # the last iteration in the file + the file's thinning interval
            # gives the iteration of the next sample that should be
            # written; last_iteration - nsamples gives the iteration of the
            # first sample in samples. Subtracting the latter from the
            # former - 1 (-1 to convert from iteration to index) therefore
            # gives the index in the samples data to start using samples.
            thin_start = fp.last_iteration(param, group) + thin_by \
                - (last_iteration - nsamples) - 1
            thinned_samples[param] = data[..., thin_start::thin_by]
    else:
        thinned_samples = samples
    return thinned_samples

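# A worked example of the thin_start arithmetic above (hypothetical
# numbers): suppose the last iteration on disk is 100, the file's
# thinned_by is 4, and ``samples`` holds iterations 101-200 (nsamples =
# 100, last_iteration = 200). The next sample that belongs on disk is
# iteration 104, so
#
#   thin_start = 100 + 4 - (200 - 100) - 1 = 3
#
# i.e., index 3 of ``samples`` (iteration 104), after which every 4th
# sample (iterations 108, 112, ...) is kept.
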
def nsamples_in_chain(start_iter, interval, niterations):
    """Calculates the number of samples in an MCMC chain given a thinning
    start, end, and interval.

    This function will work with either python scalars, or numpy arrays.

    Parameters
    ----------
    start_iter : (array of) int
        Start iteration. If negative, will count as being how many
        iterations to start before the end; otherwise, counts how many
        iterations to start after the beginning. If this is larger than
        niterations, will just return 0.
    interval : (array of) int
        Thinning interval.
    niterations : (array of) int
        The number of iterations.

    Returns
    -------
    num_samples : (array of) numpy.int
        The number of samples in a chain, >= 0.
    """
    # this is written in a slightly wonky way so that it will work with
    # either python scalars or numpy arrays; it is equivalent to:
    #   if start_iter < 0:
    #       count = min(abs(start_iter), niterations)
    #   else:
    #       count = max(niterations - start_iter, 0)
    slt0 = start_iter < 0
    sgt0 = start_iter >= 0
    count = slt0*abs(start_iter) + sgt0*(niterations - start_iter)
    # ensure count is in [0, niterations]
    cgtn = count > niterations
    cok = (count >= 0) & (count <= niterations)
    count = cgtn*niterations + cok*count
    return numpy.ceil(count / interval).astype(int)

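# A worked example (hypothetical values) showing both scalar and array use:
#
#   >>> nsamples_in_chain(10, 2, 100)
#   45    # ceil((100 - 10) / 2)
#   >>> nsamples_in_chain(-10, 2, 100)
#   5     # the last 10 iterations, every 2nd one
#   >>> nsamples_in_chain(numpy.array([0, 150]), 1, 100)
#   array([100,   0])   # a start past the end yields no samples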