# Source code for phasik.drawing.drawing

"""
Useful drawing functions
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm

from phasik.utils.utils import get_extrema_of_binary_series

# Public API of this module: the names exported by a star-import.
__all__ = [
    'plot_events',
    'plot_interval',
    'plot_phases',
    'threshold_plot',
]


def plot_events(events, ax=None, text_y_pos=None, text_x_offset=0, period=None,
                n_periods=1, add_labels=True, orientation="vertical", zorder=-1,
                alpha=1, va="bottom"):
    """Plot events as straight black lines on axes

    Parameters
    ----------
    events : list of tuples (time, name, line_style)
        time - time at which the event occurred
        name - the name of the event (also used as the line's legend label)
        line_style - any string accepted by matplotlib.lines.Line2D.set_linestyle
    ax : matplotlib.Axes, optional
        Axes on which to plot the events (default: the current axes)
    text_y_pos : float, optional
        Height at which to place the name of the event
        (default None, meaning 1.01 times the upper y-limit of `ax`)
    text_x_offset : float, optional
        Distance along x-axis by which to offset the placement of the event
        name; only its absolute value is used (default 0)
    period : float or None, optional
        Length of time of one period, if events repeat periodically.
    n_periods : int, optional
        Number of periods to draw, if events repeat periodically.
        Ignored when `period` is None.
    add_labels : bool, optional
        Whether to display the name of each vertical line, True by default.
        Horizontal lines are never annotated.
    orientation : {'vertical', 'horizontal'}, optional
        Draw each event as a vertical line at x=time (default) or as a
        horizontal line at y=time.
    zorder : float, optional
        zorder given to the lines (default -1)
    alpha : float, optional
        Opacity given to the lines (default 1)
    va : str, optional
        Vertical alignment of the event names (default 'bottom')

    Returns
    -------
    None

    Warns
    -----
    UserWarning
        If `orientation` is not one of 'vertical' or 'horizontal'; nothing is
        drawn in that case.
    """
    import warnings

    # Validate once up front instead of reporting the problem for every event.
    if orientation not in ("vertical", "horizontal"):
        warnings.warn(
            "wrong orientation, must be one of {'vertical', 'horizontal'}",
            stacklevel=2,
        )
        return

    if ax is None:
        ax = plt.gca()

    if text_y_pos is None:
        text_y_pos = 1.01 * ax.get_ylim()[1]

    # Only the magnitude of the offset matters; its direction depends on the
    # sign of the event time (see below).
    text_x_offset = abs(text_x_offset)

    # Time shift applied to the events for each period drawn. The first
    # period is always drawn unshifted; further periods only if requested.
    shifts = [0]
    if period is not None and n_periods > 1:
        shifts.extend(period * k for k in range(1, n_periods))

    for shift in shifts:
        for time, name, line_style in events:
            time = time + shift
            if orientation == "vertical":
                ax.axvline(x=time, c='k', ls=line_style, label=name,
                           zorder=zorder, alpha=alpha)
                if add_labels:
                    # Labels are shifted towards zero: left of the line for
                    # positive times, right of it otherwise.
                    if time > 0:
                        text_x_pos = time - text_x_offset
                    else:
                        text_x_pos = time + text_x_offset
                    ax.text(text_x_pos, text_y_pos, name, fontsize='small',
                            rotation=90, va=va, ha='center')
            else:
                ax.axhline(y=time, c='k', ls=line_style, label=name,
                           zorder=zorder, alpha=alpha)
def plot_phases(phases, ax=None, y_pos=None, ymin=0, ymax=1, t_offset=0):
    """Plot phases as shaded regions on axes

    Each phase is drawn as a black vertical span whose opacity increases
    with its position in `phases` (up to 0.5 for the last one), with its
    name centred horizontally on the span.

    Parameters
    ----------
    phases : list of tuples (start_time, end_time, name)
    ax : matplotlib.Axes, optional
        Axes on which to plot the phases (default: the current axes)
    y_pos : float or None, optional
        Vertical position of the phase names, as a fraction of the current
        y-axis range of `ax` (0 is the lower limit, 1 the upper limit).
        Default None, meaning 1.01, i.e. just above the top of the axes.
    ymin : float, optional
        Height at which to start shaded region, forwarded to
        ``Axes.axvspan`` (default 0)
    ymax : float, optional
        Height at which to stop shaded region, forwarded to
        ``Axes.axvspan`` (default 1)
    t_offset : float, optional
        Offset of phase on the time axis

    Returns
    -------
    None
    """
    if ax is None:
        ax = plt.gca()

    if y_pos is None:
        y_pos = 1.01

    # Convert the relative label position into data coordinates using the
    # y-limits the axes have at call time.
    y_low, y_high = ax.get_ylim()
    absolute_y_pos = y_low + y_pos * (y_high - y_low)

    n_phases = len(phases)
    if n_phases == 0:
        return

    # Opacity increment per phase, computed once: phase i (1-based) gets
    # i * alpha_step, so the last phase has opacity 0.5.
    alpha_step = 0.5 / n_phases

    for rank, (start_time, end_time, name) in enumerate(phases, start=1):
        start = start_time + t_offset
        end = end_time + t_offset
        ax.axvspan(xmin=start, xmax=end, ymin=ymin, ymax=ymax,
                   color='k', alpha=alpha_step * rank)
        ax.text((start + end) / 2, absolute_y_pos, name,
                fontweight='bold', va='center', ha='center')
def plot_interval(binary_series, times, y=0, peak=None, color='k', ax=None,
                  zorder=0, rect_height=0.5):
    """Plot a binary series as a sequence of coloured intervals

    Specifically, when a binary series has value 1, plot it as a continuous
    rectangular interval. When it has value 0 do nothing.

    Parameters
    ----------
    binary_series : ndarray
        Array of binary data to plot. It is forwarded unchanged, together
        with `times`, to ``get_extrema_of_binary_series``, which returns the
        start and end time of every interval.
        NOTE(review): the previous docstring described this as a 2D array;
        confirm the expected shape against ``get_extrema_of_binary_series``.
    times : ndarray
        1D array consisting of the corresponding time points
    y : float, optional
        Height (y-axis value) of the bottom edge of the intervals (default 0)
    peak : float or None, optional
        If given, x-position at which a red star marker is drawn, vertically
        centred on the intervals (default None: no marker)
    color : str, optional
        Color to use for the intervals (default 'k')
    ax : matplotlib.Axes, optional
        Axes to plot on (default: the current axes)
    zorder : int, optional
        Height on the z-axis which to plot the interval (default 0)
    rect_height : float, optional
        Height of the rectangles along the y-axis (default 0.5)

    Returns
    -------
    None
    """
    if ax is None:
        ax = plt.gca()

    starts, ends = get_extrema_of_binary_series(binary_series, times)

    # One filled rectangle per interval, anchored at its start time.
    for start, end in zip(starts, ends):
        interval = patches.Rectangle((start, y), end - start, rect_height,
                                     fill=True, color=color, zorder=zorder)
        ax.add_patch(interval)

    if peak is not None:
        # Mark the peak with a red star at mid-height of the rectangles.
        ax.plot(peak, y + rect_height / 2, 'r*')
def threshold_plot(x, y, threshold, color_below_threshold, color_above_threshold, ax=None):
    """Plot a curve in two colours, depending on a threshold on its y values

    Parameters
    ----------
    x : array
        1D array of values to plot along x-axis
    y : array
        1D array of values to plot along y-axis
    threshold : float
        Parts of the curve with y >= threshold are drawn in
        `color_above_threshold`, the rest in `color_below_threshold`
    color_below_threshold : color
        Colour to use below the threshold (any colour accepted by
        ``matplotlib.colors.ListedColormap``)
    color_above_threshold : color
        Colour to use at or above the threshold (any colour accepted by
        ``matplotlib.colors.ListedColormap``)
    ax : matplotlib.Axes, optional
        Axes to use (default: the current axes)

    Returns
    -------
    matplotlib.collections.LineCollection
        The collection of line segments that was added to `ax`
    """
    if ax is None:
        ax = plt.gca()

    y_min = np.min(y)
    y_max = np.max(y)

    # Two-colour map with a single boundary at the threshold. The outer
    # boundaries are widened to include the threshold so that the sequence
    # stays monotonic, as BoundaryNorm requires, even when the threshold
    # lies outside the range of the data.
    cmap = ListedColormap([color_below_threshold, color_above_threshold])
    boundaries = [min(y_min, threshold), threshold, max(y_max, threshold)]
    norm = BoundaryNorm(boundaries, cmap.N)

    # Create a set of line segments so that we can color them individually.
    # This creates the points as a N x 1 x 2 array so that we can stack points
    # together easily to get the segments. The segments array for line
    # collection needs to be numlines x points per line x 2 (x and y).
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    # Create the line collection object, setting the colormapping parameters.
    # Have to set the actual values used for colormapping separately.
    line_collection = LineCollection(segments, cmap=cmap, norm=norm)
    line_collection.set_array(y)
    ax.add_collection(line_collection)

    ax.set_xlim(np.min(x), np.max(x))

    # Pad the y-limits by 10% of each extreme, always away from the data.
    # Multiplying by 1.1 unconditionally would move a positive minimum (or a
    # negative maximum) *into* the data range and clip the curve.
    y_lower = y_min * 1.1 if y_min < 0 else y_min * 0.9
    y_upper = y_max * 1.1 if y_max > 0 else y_max * 0.9
    ax.set_ylim(y_lower, y_upper)

    return line_collection