Source code for pycbc.results.scatter_histograms

# Copyright (C) 2016 Miriam Cabero Mueller, Collin Capano
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# =============================================================================
#                                   Preamble
# =============================================================================
"""
Module to generate figures with scatter plots and histograms.
"""

import itertools
import sys

import numpy

import scipy.stats

import matplotlib

# Only if a backend is not already set ... This should really *not* be done
# here, but in the executables you should set matplotlib.use()
# This matches the check that matplotlib does internally, but this *may* be
# version dependenant. If this is a problem then remove this and control from
# the executables directly.
if 'matplotlib.backends' not in sys.modules:  # nopep8
    # No backend has been loaded yet, so pick the non-interactive Agg
    # backend before pyplot is imported below.
    matplotlib.use('agg')

from matplotlib import (offsetbox, pyplot, gridspec, colors)

from pycbc.results import str_utils
# NOTE(review): the module path was lost from this line; FieldArray is
# defined in pycbc.io.record -- confirm against the upstream file.
from pycbc.io.record import FieldArray

def create_axes_grid(parameters, labels=None, height_ratios=None,
                     width_ratios=None, no_diagonals=False):
    """Creates a figure holding one axis per combination of parameters.

    Parameters
    ----------
    parameters : list
        Names of the variables to be plotted.
    labels : {None, dict}, optional
        A dictionary of parameters -> parameter labels. If None, the
        parameter names themselves are used as labels.
    height_ratios : {None, list}, optional
        Set the height ratios of the axes; see `matplotlib.gridspec.GridSpec`
        for details.
    width_ratios : {None, list}, optional
        Set the width ratios of the axes; see `matplotlib.gridspec.GridSpec`
        for details.
    no_diagonals : {False, bool}, optional
        Do not produce axes for the same parameter on both axes.

    Returns
    -------
    fig : pyplot.figure
        The figure that was created.
    axis_dict : dict
        A dictionary mapping the parameter combinations to the axis and their
        location in the subplots grid; i.e., the key, values are:
        `{('param1', 'param2'): (pyplot.axes, row index, column index)}`

    Raises
    ------
    ValueError
        If `labels` is given but lacks an entry for one of the parameters.
    """
    if labels is None:
        labels = {name: name for name in parameters}
    elif any(name not in labels for name in parameters):
        raise ValueError("labels must be provided for all parameters")

    # without diagonals the grid needs one fewer row and column
    ndim = len(parameters) - 1 if no_diagonals else len(parameters)
    # small grids use a fixed size; larger ones grow with the grid
    figsize = (8, 7) if ndim < 3 else (ndim*3 - 1, ndim*3 - 2)
    fig = pyplot.figure(figsize=figsize)
    grid = gridspec.GridSpec(ndim, ndim, width_ratios=width_ratios,
                             height_ratios=height_ratios,
                             wspace=0.05, hspace=0.05)

    # the parameter pairs that get a visible axis
    wanted = list(itertools.combinations(parameters, 2))
    if not no_diagonals:
        wanted.extend((name, name) for name in parameters)

    # when the diagonal is dropped, rows start at the second parameter
    row_shift = 1 if no_diagonals else 0
    axis_dict = {}
    for nrow, ncolumn in itertools.product(range(ndim), repeat=2):
        # cells of the grid are numbered row by row
        ax = pyplot.subplot(grid[nrow * ndim + ncolumn])
        px = parameters[ncolumn]
        py = parameters[nrow + row_shift]
        if (px, py) not in wanted:
            # make non-used axes invisible
            ax.axis('off')
            continue
        axis_dict[px, py] = (ax, nrow, ncolumn)
        # x labels only on the bottom row
        if nrow + 1 == ndim:
            ax.set_xlabel('{}'.format(labels[px]), fontsize=18)
        else:
            pyplot.setp(ax.get_xticklabels(), visible=False)
        # y labels only on the left column
        if ncolumn == 0:
            ax.set_ylabel('{}'.format(labels[py]), fontsize=18)
        else:
            pyplot.setp(ax.get_yticklabels(), visible=False)
    return fig, axis_dict
def get_scale_fac(fig, fiducial_width=8, fiducial_height=7):
    """Returns the factor by which to scale fonts for the given figure.

    The factor is the square root of the ratio between the figure's area
    (from ``fig.get_size_inches()``) and the area of a reference figure with
    dimensions (`fiducial_width`, `fiducial_height`).
    """
    width, height = fig.get_size_inches()
    fiducial_area = fiducial_width*fiducial_height
    return (width*height/fiducial_area)**0.5
def construct_kde(samples_array, use_kombine=False, kdeargs=None):
    """Constructs a KDE from the given samples.

    Parameters
    ----------
    samples_array : array
        Array of values to construct the KDE for. The array is not modified;
        a shuffled copy is used internally.
    use_kombine : bool, optional
        Use kombine's clustered KDE instead of scipy's. Default is False.
    kdeargs : dict, optional
        Additional arguments to pass to the KDE. Can be any argument
        recognized by :py:func:`scipy.stats.gaussian_kde` or
        :py:func:`kombine.clustered_kde.optimized_kde`. In either case, you
        can also set ``max_kde_samples`` to limit the number of samples that
        are used for KDE construction. The dictionary is not modified.

    Returns
    -------
    kde :
        The KDE.

    Raises
    ------
    ImportError
        If ``use_kombine`` is True and kombine cannot be imported.
    """
    # Shuffle a copy rather than the caller's array: shuffling in place
    # would silently reorder the samples the caller passed in.
    samples_array = numpy.array(samples_array, copy=True, subok=True)
    # Randomize the order so that truncating to ``max_kde_samples`` does not
    # just keep the leading samples. The fixed seed makes the selection
    # repeatable; note that this re-seeds numpy's global random state.
    numpy.random.seed(0)
    numpy.random.shuffle(samples_array)
    # work on a copy of the arguments so the pop below is not seen by caller
    kdeargs = {} if kdeargs is None else kdeargs.copy()
    # a limit of None keeps every sample
    max_nsamples = kdeargs.pop('max_kde_samples', None)
    samples_array = samples_array[:max_nsamples]
    if use_kombine:
        try:
            import kombine
        except ImportError:
            raise ImportError("kombine is not installed.")
        return kombine.clustered_kde.optimized_kde(samples_array, **kdeargs)
    # scipy wants the samples along the last axis, hence the transpose
    return scipy.stats.gaussian_kde(samples_array.T, **kdeargs)
def create_density_plot(xparam, yparam, samples, plot_density=True,
                        plot_contours=True, percentiles=None, cmap='viridis',
                        contour_color=None, label_contours=True,
                        contour_linestyles=None,
                        xmin=None, xmax=None, ymin=None, ymax=None,
                        exclude_region=None, fig=None, ax=None,
                        use_kombine=False, kdeargs=None):
    """Computes and plots posterior density and confidence intervals using the
    given samples.

    Parameters
    ----------
    xparam : string
        The parameter to plot on the x-axis.
    yparam : string
        The parameter to plot on the y-axis.
    samples : dict, numpy structured array, or FieldArray
        The samples to plot.
    plot_density : {True, bool}
        Plot a color map of the density.
    plot_contours : {True, bool}
        Plot contours showing the n-th percentiles of the density.
    percentiles : {None, float or array}
        What percentile contours to draw. If None, will plot the 50th
        and 90th percentiles.
    cmap : {'viridis', string}
        The name of the colormap to use for the density plot.
    contour_color : {None, string}
        What color to make the contours. Default is white for density
        plots and black for other plots.
    label_contours : bool, optional
        Whether to label the contours. Default is True.
    contour_linestyles : list, optional
        Linestyles to use for the contours. Default (None) will use solid.
    xmin : {None, float}
        Minimum value to plot on x-axis.
    xmax : {None, float}
        Maximum value to plot on x-axis.
    ymin : {None, float}
        Minimum value to plot on y-axis.
    ymax : {None, float}
        Maximum value to plot on y-axis.
    exclude_region : {None, str}
        Exclude the specified region when plotting the density or contours.
        Must be a string in terms of `xparam` and `yparam` that is
        understandable by numpy's logical evaluation. For example, if
        `xparam = m_1` and `yparam = m_2`, and you want to exclude the region
        for which `m_2` is greater than `m_1`, then exclude region should be
        `'m_2 > m_1'`.
    fig : {None, pyplot.figure}
        Add the plot to the given figure. If None and ax is None, will create
        a new figure.
    ax : {None, pyplot.axes}
        Draw plot on the given axis. If None, will create a new axis from
        `fig`.
    use_kombine : {False, bool}
        Use kombine's KDE to calculate density. Otherwise, will use
        `scipy.stats.gaussian_kde.` Default is False.
    kdeargs : dict, optional
        Pass the given keyword arguments to the KDE.

    Returns
    -------
    fig : pyplot.figure
        The figure the plot was made on.
    ax : pyplot.axes
        The axes the plot was drawn on.
    """
    if percentiles is None:
        percentiles = numpy.array([50., 90.])
    # Contour levels are percentiles of the density values, so the region
    # containing p% of the samples is bounded by the (100-p)th percentile.
    # atleast_1d lets a single float be given, as documented; a 0-d array
    # cannot be sorted or iterated over.
    percentiles = 100. - numpy.atleast_1d(percentiles)
    percentiles.sort()

    if ax is None and fig is None:
        fig = pyplot.figure()
    if ax is None:
        ax = fig.add_subplot(111)

    # convert samples to array and construct kde
    xsamples = samples[xparam]
    ysamples = samples[yparam]
    arr = numpy.vstack((xsamples, ysamples)).T
    kde = construct_kde(arr, use_kombine=use_kombine, kdeargs=kdeargs)

    # construct grid to evaluate on
    if xmin is None:
        xmin = xsamples.min()
    if xmax is None:
        xmax = xsamples.max()
    if ymin is None:
        ymin = ysamples.min()
    if ymax is None:
        ymax = ysamples.max()
    npts = 100
    X, Y = numpy.mgrid[
        xmin:xmax:complex(0, npts),  # pylint:disable=invalid-slice-index
        ymin:ymax:complex(0, npts)]  # pylint:disable=invalid-slice-index
    pos = numpy.vstack([X.ravel(), Y.ravel()])
    if use_kombine:
        # the kombine values are exponentiated before use
        Z = numpy.exp(kde(pos.T).reshape(X.shape))
        draw = kde.draw
    else:
        Z = kde(pos).T.reshape(X.shape)
        draw = kde.resample

    if exclude_region is not None:
        # convert X,Y to a single FieldArray so we can use its ability to
        # evaluate strings
        farr = FieldArray.from_kwargs(**{xparam: X, yparam: Y})
        Z[farr[exclude_region]] = 0.

    if plot_density:
        # Z is indexed [x, y]; rotate so that rows run along y for imshow
        ax.imshow(numpy.rot90(Z), extent=[xmin, xmax, ymin, ymax],
                  aspect='auto', cmap=cmap, zorder=1)
        if contour_color is None:
            contour_color = 'w'

    if plot_contours:
        # compute the percentile values from the density at points drawn
        # from the KDE itself
        resamps = kde(draw(int(npts**2)))
        if use_kombine:
            resamps = numpy.exp(resamps)
        s = numpy.percentile(resamps, percentiles)
        if contour_color is None:
            contour_color = 'k'
        # make linewidths thicker if not plotting density for clarity
        lw = 1 if plot_density else 2
        ct = ax.contour(X, Y, Z, s, colors=contour_color, linewidths=lw,
                        linestyles=contour_linestyles, zorder=3)
        # label contours
        if label_contours:
            # undo the 100 - p conversion to label with the requested values
            lbls = ['{p}%'.format(p=int(p)) for p in (100. - percentiles)]
            fmt = dict(zip(ct.levels, lbls))
            fs = 12
            ax.clabel(ct, ct.levels, inline=True, fmt=fmt, fontsize=fs)

    return fig, ax
def create_marginalized_hist(ax, values, label, percentiles=None,
                             color='k', fillcolor='gray', linecolor='navy',
                             linestyle='-', plot_marginal_lines=True,
                             title=True, expected_value=None,
                             expected_color='red', rotated=False,
                             plot_min=None, plot_max=None, log_scale=False):
    """Plots a 1D marginalized histogram of the given param from the given
    samples.

    NaN entries in ``values`` are ignored when computing the bin range and
    the percentiles.

    Parameters
    ----------
    ax : pyplot.Axes
        The axes on which to draw the plot.
    values : array
        The parameter values to plot.
    label : str
        A label to use for the title.
    percentiles : {None, float or array}
        What percentiles to draw lines at. If None, will draw lines at
        `[5, 50, 95]` (i.e., the bounds on the upper 90th percentile and the
        median).
    color : {'k', string}
        What color to make the histogram; default is black.
    fillcolor : {'gray', string, or None}
        What color to fill the histogram with. Set to None to not fill the
        histogram. Default is 'gray'.
    linecolor : {'navy', string}
        What color to use for the percentile lines. Default is 'navy'.
    linestyle : str, optional
        What line style to use for the histogram. Default is '-'.
    plot_marginal_lines : bool, optional
        Put vertical lines at the marginal percentiles. Default is True.
    title : bool, optional
        Add a title with a estimated value +/- uncertainty. The estimated
        value is the percentile halfway between the max/min of
        ``percentiles``, while the uncertainty is given by the max/min of the
        ``percentiles``. If no percentiles are specified, defaults to quoting
        the median +/- 95/5 percentiles.
    expected_value : {None, float}
        If given, draw a line at this value.
    expected_color : {'red', string}
        What color to make the expected-value line. Default is 'red'.
    rotated : {False, bool}
        Plot the histogram on the y-axis instead of the x. Default is False.
    plot_min : {None, float}
        The minimum value to plot. If None, will default to whatever `pyplot`
        creates.
    plot_max : {None, float}
        The maximum value to plot. If None, will default to whatever `pyplot`
        creates.
    log_scale : boolean
        Should the histogram bins be logarithmically spaced
    """
    htype = 'step' if fillcolor is None else 'stepfilled'
    orientation = 'horizontal' if rotated else 'vertical'
    vmin = numpy.nanmin(values)
    vmax = numpy.nanmax(values)
    if log_scale:
        bins = numpy.logspace(numpy.log10(vmin), numpy.log10(vmax), 50)
    else:
        bins = numpy.linspace(vmin, vmax, 50)
    ax.hist(values, bins=bins, histtype=htype, orientation=orientation,
            facecolor=fillcolor, edgecolor=color, ls=linestyle, lw=2,
            density=True)
    if percentiles is None:
        percentiles = [5., 50., 95.]
    # nanpercentile keeps this consistent with the NaN-ignoring bin range
    # above; a plain percentile returns NaN as soon as one value is NaN
    if len(percentiles) > 0:
        plotp = numpy.nanpercentile(values, percentiles)
    else:
        plotp = []
    if plot_marginal_lines:
        for val in plotp:
            if rotated:
                ax.axhline(y=val, ls='dashed', color=linecolor, lw=2,
                           zorder=3)
            else:
                ax.axvline(x=val, ls='dashed', color=linecolor, lw=2,
                           zorder=3)
    # plot expected
    if expected_value is not None:
        if rotated:
            ax.axhline(expected_value, color=expected_color, lw=1.5,
                       zorder=2)
        else:
            ax.axvline(expected_value, color=expected_color, lw=1.5,
                       zorder=2)
    if title:
        if len(percentiles) > 0:
            minp = min(percentiles)
            maxp = max(percentiles)
            medp = (maxp + minp) / 2.
        else:
            minp = 5
            medp = 50
            maxp = 95
        values_min = numpy.nanpercentile(values, minp)
        values_med = numpy.nanpercentile(values, medp)
        values_max = numpy.nanpercentile(values, maxp)
        negerror = values_med - values_min
        poserror = values_max - values_med
        fmt = '${0}$'.format(str_utils.format_value(
            values_med, negerror, plus_error=poserror))
        if rotated:
            ax.yaxis.set_label_position("right")
        # sets colored title for marginal histogram
        set_marginal_histogram_title(ax, fmt, color, label=label,
                                     rotated=rotated)
    # remove ticks and set limits
    if rotated:
        # the density axis is x here: drop its ticks and label
        ax.set_xticks([])
        ax.set_xlabel('')
        # set limits, keeping pyplot's choice where none was requested
        ymin, ymax = ax.get_ylim()
        if plot_min is not None:
            ymin = plot_min
        if plot_max is not None:
            ymax = plot_max
        ax.set_ylim(ymin, ymax)
    else:
        # the density axis is y here: drop its ticks and label
        ax.set_yticks([])
        ax.set_ylabel('')
        # set limits, keeping pyplot's choice where none was requested
        xmin, xmax = ax.get_xlim()
        if plot_min is not None:
            xmin = plot_min
        if plot_max is not None:
            xmax = plot_max
        ax.set_xlim(xmin, xmax)
def set_marginal_histogram_title(ax, fmt, color, label=None, rotated=False):
    """Sets, or extends, the title of a marginal histogram.

    The title is built from text boxes that are remembered on the axes as
    ``ax.title_boxes``; the artist currently drawn is kept as
    ``ax.title_anchor``. The first call creates the title; later calls on the
    same axes append another text box to it.

    Parameters
    ----------
    ax : Axes
        The `Axes` instance for the plot.
    fmt : str
        The string to add to the title.
    color : str
        The color of the text to add to the title.
    label : str
        If title does not exist, then include label at beginning of the
        string.
    rotated : bool
        If `True` then rotate the text 270 degrees for sideways title.
    """
    # orientation-dependent settings: text angle, where to anchor the title
    # in axes coordinates, and whether boxes stack vertically or horizontally
    if rotated:
        angle = 270
        xshift, yshift = 1.05, 1.0
        packer = offsetbox.VPacker
    else:
        angle = 0
        xshift = 0.0
        yshift = 1.15 if len(ax.get_figure().axes) > 1 else 1.05
        packer = offsetbox.HPacker

    if hasattr(ax, "title_boxes"):
        # a title is already there: take it down and extend its boxes
        ax.title_anchor.remove()
        text = " {}".format(fmt)
        existing = ax.title_boxes
    else:
        # first title on these axes: lead with the label
        text = "{} = {}".format(label, fmt)
        existing = []

    textprops = dict(color=color, size=15, rotation=angle, ha='left',
                     va='bottom')
    new_box = offsetbox.TextArea(text, textprops=textprops)
    # save the list of text boxes as an attribute for later calls
    ax.title_boxes = existing + [new_box]
    packed = packer(children=ax.title_boxes, align="bottom", pad=0, sep=5)

    # add the new title and keep a reference to it as an attribute
    anchored = offsetbox.AnchoredOffsetbox(
        loc=2, child=packed, pad=0., frameon=False,
        bbox_to_anchor=(xshift, yshift), bbox_transform=ax.transAxes,
        borderpad=0.)
    ax.title_anchor = ax.add_artist(anchored)
def create_multidim_plot(parameters, samples, labels=None,
                         mins=None, maxs=None, expected_parameters=None,
                         expected_parameters_color='r',
                         plot_marginal=True, plot_scatter=True,
                         plot_maxl=False,
                         plot_marginal_lines=True,
                         marginal_percentiles=None, contour_percentiles=None,
                         marginal_title=True, marginal_linestyle='-',
                         zvals=None, show_colorbar=True, cbar_label=None,
                         vmin=None, vmax=None, scatter_cmap='plasma',
                         scatter_log_cmap=False, log_parameters=None,
                         plot_density=False, plot_contours=True,
                         density_cmap='viridis',
                         contour_color=None, label_contours=True,
                         contour_linestyles=None,
                         hist_color='black',
                         line_color=None, fill_color='gray',
                         use_kombine=False, kdeargs=None,
                         fig=None, axis_dict=None):
    """Generate a figure with several plots and histograms.

    Parameters
    ----------
    parameters: list
        Names of the variables to be plotted. Parameters that cannot be
        retrieved from `samples` are skipped.
    samples : FieldArray
        A field array of the samples to plot.
    labels: dict, optional
        A dictionary mapping parameters to labels. If none provided, will just
        use the parameter strings as the labels.
    mins : {None, dict}, optional
        Minimum value for the axis of each variable in `parameters`.
        If None, it will use the minimum of the corresponding variable in
        `samples`.
    maxs : {None, dict}, optional
        Maximum value for the axis of each variable in `parameters`.
        If None, it will use the maximum of the corresponding variable in
        `samples`.
    expected_parameters : {None, dict}, optional
        Expected values of `parameters`, as a dictionary mapping parameter
        names -> values. A cross will be plotted at the location of the
        expected parameters on axes that plot any of the expected parameters.
    expected_parameters_color : {'r', string}, optional
        What color to make the expected parameters cross.
    plot_marginal : {True, bool}
        Plot the marginalized distribution on the diagonals. If False, the
        diagonal axes will be turned off.
    plot_scatter : {True, bool}
        Plot each sample point as a scatter plot. Ignored if there is only
        one parameter.
    plot_maxl : {False, bool}
        Mark the sample with the largest ``loglikelihood`` with an 'x' on
        the off-diagonal axes. Requires `samples` to have a
        ``loglikelihood`` field.
    plot_marginal_lines : {True, bool}
        Draw the percentile lines on the 1D histograms.
    marginal_percentiles : {None, array}
        What percentiles to draw lines at on the 1D histograms.
        If None, will draw lines at `[5, 50, 95]` (i.e., the bounds on the
        upper 90th percentile and the median).
    marginal_title : bool, optional
        Add a title over the 1D marginal plots that gives an estimated value
        +/- uncertainty. The estimated value is the percentile halfway between
        the max/min of ``marginal_percentiles``, while the uncertainty is given
        by the max/min of the ``marginal_percentiles``. If no
        ``marginal_percentiles`` are specified, the median +/- 95/5
        percentiles will be quoted.
    marginal_linestyle : str, optional
        What line style to use for the marginal histograms.
    contour_percentiles : {None, array}
        What percentile contours to draw on the scatter plots. If None,
        will plot the 50th and 90th percentiles.
    zvals : {None, array}
        An array to use for coloring the scatter plots. If None, scatter points
        will be the same color.
    show_colorbar : {True, bool}
        Show the colorbar of zvalues used for the scatter points. A ValueError
        will be raised if zvals is None and this is True.
    cbar_label : {None, str}
        Specify a label to add to the colorbar.
    vmin: {None, float}, optional
        Minimum value for the colorbar. If None, will use the minimum of zvals.
    vmax: {None, float}, optional
        Maximum value for the colorbar. If None, will use the maximum of
        zvals.
    scatter_cmap : {'plasma', string}
        The color map to use for the scatter points. Default is 'plasma'.
    scatter_log_cmap : boolean
        Should the scatter point coloring be on a log scale? Default False
    log_parameters : list or None
        Which parameters should be plotted on a log scale
    plot_density : {False, bool}
        Plot the density of points as a color map.
    plot_contours : {True, bool}
        Draw contours showing the 50th and 90th percentile confidence regions.
    density_cmap : {'viridis', string}
        The color map to use for the density plot.
    contour_color : {None, string}
        The color to use for the contour lines. Defaults to white for
        density plots, navy for scatter plots without zvals, and black
        otherwise.
    label_contours : bool, optional
        Whether to label the contours. Default is True.
    contour_linestyles : list, optional
        Linestyles to use for the contours. Default (None) will use solid.
    hist_color : {'black', string}
        The color of the 1D histograms.
    line_color : {None, string}
        The color of the percentile lines on the 1D histograms.
    fill_color : {'gray', string, or None}
        The color to fill the 1D histograms with.
    use_kombine : {False, bool}
        Use kombine's KDE to calculate density. Otherwise, will use
        `scipy.stats.gaussian_kde.` Default is False.
    kdeargs : dict, optional
        Pass the given keyword arguments to the KDE.
    fig : pyplot.figure
        Use the given figure instead of creating one.
    axis_dict : dict
        Use the given dictionary of axes instead of creating one.

    Returns
    -------
    fig : pyplot.figure
        The figure that was created.
    axis_dict : dict
        A dictionary mapping the parameter combinations to the axis and their
        location in the subplots grid; i.e., the key, values are:
        `{('param1', 'param2'): (pyplot.axes, row index, column index)}`

    Raises
    ------
    ValueError
        If a colorbar is requested for a scatter plot without `zvals`, or if
        `plot_maxl` is set and `samples` has no ``loglikelihood`` field.
    """
    if labels is None:
        labels = {p: p for p in parameters}
    if log_parameters is None:
        log_parameters = []
    # set up the figure with a grid of axes
    # if only plotting 2 parameters, make the marginal plots smaller
    nparams = len(parameters)
    if nparams == 2 and plot_marginal:
        width_ratios = [3, 1]
        height_ratios = [1, 3]
    else:
        width_ratios = height_ratios = None

    # only plot scatter if more than one parameter
    plot_scatter = plot_scatter and nparams > 1

    # Sort zvals to get higher values on top in scatter plots
    if plot_scatter:
        if zvals is not None:
            sort_indices = zvals.argsort()
            zvals = zvals[sort_indices]
            samples = samples[sort_indices]
            if contour_color is None:
                contour_color = 'k'
        elif show_colorbar:
            raise ValueError("must provide z values to create a colorbar")
        else:
            # just make all scatter points same color
            zvals = 'gray'
            if plot_contours and contour_color is None:
                contour_color = 'navy'

    if plot_maxl:
        # make sure loglikelihood is provided
        if 'loglikelihood' not in samples.fieldnames:
            raise ValueError("plot-maxl requires loglikelihood")
        # computed after the sort above so the index matches the sample
        # order used for plotting
        maxidx = samples['loglikelihood'].argmax()

    # create the axis grid
    if fig is None and axis_dict is None:
        fig, axis_dict = create_axes_grid(
            parameters, labels=labels,
            width_ratios=width_ratios, height_ratios=height_ratios,
            no_diagonals=not plot_marginal)

    # convert samples to a dictionary to avoid re-computing derived parameters
    # every time they are needed
    # only try to plot what's available
    sd = {}
    for p in parameters:
        try:
            sd[p] = samples[p]
        except (ValueError, TypeError, IndexError):
            continue
    samples = sd
    parameters = list(sd.keys())

    # values for axis bounds; given dictionaries are copied
    if mins is None:
        mins = {p: samples[p].min() for p in parameters}
    else:
        mins = dict(mins.items())
    if maxs is None:
        maxs = {p: samples[p].max() for p in parameters}
    else:
        maxs = dict(maxs.items())

    # Diagonals...
    if plot_marginal:
        for pi, param in enumerate(parameters):
            ax, _, _ = axis_dict[param, param]
            # if only plotting 2 parameters and on the second parameter,
            # rotate the marginal plot
            rotated = nparams == 2 and pi == nparams-1
            # see if there are expected values
            if expected_parameters is not None:
                try:
                    expected_value = expected_parameters[param]
                except KeyError:
                    expected_value = None
            else:
                expected_value = None
            create_marginalized_hist(
                ax, samples[param], label=labels[param],
                color=hist_color, fillcolor=fill_color,
                log_scale=param in log_parameters,
                plot_marginal_lines=plot_marginal_lines,
                linestyle=marginal_linestyle, linecolor=line_color,
                title=marginal_title, expected_value=expected_value,
                expected_color=expected_parameters_color,
                rotated=rotated, plot_min=mins[param], plot_max=maxs[param],
                percentiles=marginal_percentiles)

    # Off-diagonals...
    # handle of the last scatter drawn; the colorbar is built from it
    scatter_artist = None
    for px, py in axis_dict:
        if px == py or px not in parameters or py not in parameters:
            continue
        ax, _, _ = axis_dict[px, py]
        if plot_scatter:
            # fade the points when a density map is drawn as well
            alpha = 0.3 if plot_density else 1.
            if scatter_log_cmap:
                cmap_norm = colors.LogNorm(vmin=vmin, vmax=vmax)
            else:
                cmap_norm = colors.Normalize(vmin=vmin, vmax=vmax)

            scatter_artist = ax.scatter(
                x=samples[px], y=samples[py], c=zvals, s=5,
                edgecolors='none', norm=cmap_norm, cmap=scatter_cmap,
                alpha=alpha, zorder=2)

        if plot_contours or plot_density:
            # Exclude out-of-bound regions
            # this is a bit kludgy; should probably figure out a better
            # solution to eventually allow for more than just m_p m_s
            if (px == 'm_p' and py == 'm_s') or (py == 'm_p' and px == 'm_s'):
                exclude_region = 'm_s > m_p'
            else:
                exclude_region = None
            create_density_plot(
                px, py, samples, plot_density=plot_density,
                plot_contours=plot_contours, cmap=density_cmap,
                percentiles=contour_percentiles,
                contour_color=contour_color, label_contours=label_contours,
                contour_linestyles=contour_linestyles,
                xmin=mins[px], xmax=maxs[px],
                ymin=mins[py], ymax=maxs[py],
                exclude_region=exclude_region, ax=ax,
                use_kombine=use_kombine, kdeargs=kdeargs)

        if plot_maxl:
            maxlx = samples[px][maxidx]
            maxly = samples[py][maxidx]
            ax.scatter(maxlx, maxly, marker='x', s=20, c=contour_color,
                       zorder=5)

        if expected_parameters is not None:
            # draw whichever of the two lines has an expected value
            try:
                ax.axvline(expected_parameters[px], lw=1.5,
                           color=expected_parameters_color, zorder=5)
            except KeyError:
                pass
            try:
                ax.axhline(expected_parameters[py], lw=1.5,
                           color=expected_parameters_color, zorder=5)
            except KeyError:
                pass

        ax.set_xlim(mins[px], maxs[px])
        ax.set_ylim(mins[py], maxs[py])

    # adjust tick number for large number of plots
    if len(parameters) > 3:
        for px, py in axis_dict:
            ax, _, _ = axis_dict[px, py]
            ax.set_xticks(reduce_ticks(ax, 'x', maxticks=3))
            ax.set_yticks(reduce_ticks(ax, 'y', maxticks=3))

    # a colorbar needs a scatter to take its colors from; there is none if
    # no off-diagonal axis ended up being drawn
    if plot_scatter and show_colorbar and scatter_artist is not None:
        # compute font size based on fig size
        scale_fac = get_scale_fac(fig)
        fig.subplots_adjust(right=0.85, wspace=0.03)
        cbar_ax = fig.add_axes([0.9, 0.1, 0.03, 0.8])
        cb = fig.colorbar(scatter_artist, cax=cbar_ax)
        if cbar_label is not None:
            cb.set_label(cbar_label, fontsize=12*scale_fac)
        # NOTE(review): this statement was truncated in the source; the tick
        # label scaling is reconstructed -- confirm against upstream.
        cb.ax.tick_params(labelsize=8*scale_fac)

    return fig, axis_dict
def remove_common_offset(arr):
    """Given an array of data, removes a common offset > 1000, returning the
    removed value.

    An offset is only removed if all values have the same sign and the
    smallest and largest magnitudes are of the same order of magnitude.

    Parameters
    ----------
    arr : array
        The data to remove the offset from. It is not modified.

    Returns
    -------
    arr : array
        The data with the offset subtracted; the input itself if no offset
        was removed.
    offset : int
        The offset that was subtracted; 0 if none was.
    """
    offset = 0
    # an empty array has no minimum or maximum to inspect
    if arr.size == 0:
        return arr, offset
    isneg = (arr <= 0).all()
    # make sure all values have the same sign
    if isneg or (arr >= 0).all():
        # only remove offset if the minimum and maximum values are the same
        # order of magnitude and > O(1000)
        magnitudes = abs(arr)
        minpwr = numpy.log10(magnitudes.min())
        maxpwr = numpy.log10(magnitudes.max())
        if numpy.floor(minpwr) == numpy.floor(maxpwr) and minpwr > 3:
            offset = numpy.floor(10**minpwr)
            if isneg:
                offset *= -1
            arr = arr - offset
    return arr, int(offset)
def reduce_ticks(ax, which, maxticks=3):
    """Given a pyplot axis, resamples its `which`-axis ticks such that are at
    most `maxticks` left.

    Parameters
    ----------
    ax : axis
        The axis to adjust.
    which : {'x' | 'y'}
        Which axis to adjust.
    maxticks : {3, int}
        Maximum number of ticks to use.

    Returns
    -------
    array
        An array of the selected ticks. The ticks are returned unchanged if
        there are no more than `maxticks` of them.
    """
    ticks = getattr(ax, 'get_{}ticks'.format(which))()
    if len(ticks) > maxticks:
        # make sure the left/right value is not at the edge: drop an end
        # tick that lies within a tenth of the axis range from the limit
        minax, maxax = getattr(ax, 'get_{}lim'.format(which))()
        dw = abs(maxax-minax)/10.
        start_idx, end_idx = 0, len(ticks)
        if ticks[0] < minax + dw:
            start_idx += 1
        if ticks[-1] > maxax - dw:
            end_idx -= 1
        # Stride through the ticks that are left. Rounding the stride up
        # (ceil division) guarantees at most maxticks survive; rounding down
        # could leave more, e.g. 7 ticks with a stride of 2 gives 4.
        nleft = end_idx - start_idx
        fac = max(1, -(-nleft // maxticks))
        ticks = ticks[start_idx:end_idx:fac]
    return ticks