Source code for pymatgen.util.plotting

# coding: utf-8
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.

from __future__ import division, unicode_literals
import math
import numpy as np

from monty.dev import deprecated

from pymatgen.core.periodic_table import Element

"""
Utilities for generating nicer plots.
"""


__author__ = "Shyue Ping Ong"
__copyright__ = "Copyright 2012, The Materials Project"
__version__ = "0.1"
__maintainer__ = "Shyue Ping Ong"
__email__ = "shyuep@gmail.com"
__date__ = "Mar 13, 2012"


[docs]def pretty_plot(width=8, height=None, plt=None, dpi=None, color_cycle=("qualitative", "Set1_9")): """ Provides a publication quality plot, with nice defaults for font sizes etc. Args: width (float): Width of plot in inches. Defaults to 8in. height (float): Height of plot in inches. Defaults to width * golden ratio. plt (matplotlib.pyplot): If plt is supplied, changes will be made to an existing plot. Otherwise, a new plot will be created. dpi (int): Sets dot per inch for figure. Defaults to 300. color_cycle (tuple): Set the color cycle for new plots to one of the color sets in palettable. Defaults to a qualitative Set1_9. Returns: Matplotlib plot object with properly sized fonts. """ ticksize = int(width * 2.5) golden_ratio = (math.sqrt(5) - 1) / 2 if not height: height = int(width * golden_ratio) if plt is None: import matplotlib.pyplot as plt import importlib mod = importlib.import_module("palettable.colorbrewer.%s" % color_cycle[0]) colors = getattr(mod, color_cycle[1]).mpl_colors from cycler import cycler plt.figure(figsize=(width, height), facecolor="w", dpi=dpi) ax = plt.gca() ax.set_prop_cycle(cycler('color', colors)) else: fig = plt.gcf() fig.set_size_inches(width, height) plt.xticks(fontsize=ticksize) plt.yticks(fontsize=ticksize) ax = plt.gca() ax.set_title(ax.get_title(), size=width * 4) labelsize = int(width * 3) ax.set_xlabel(ax.get_xlabel(), size=labelsize) ax.set_ylabel(ax.get_ylabel(), size=labelsize) return plt
@deprecated(pretty_plot, "get_publication_quality_plot has been renamed " "pretty_plot. This stub will be removed in pmg 2018.01.01.") def get_publication_quality_plot(*args, **kwargs): return pretty_plot(*args, **kwargs)
[docs]def pretty_plot_two_axis(x, y1, y2, xlabel=None, y1label=None, y2label=None, width=8, height=None, dpi=300): """ Variant of pretty_plot that does a dual axis plot. Adapted from matplotlib examples. Makes it easier to create plots with different axes. Args: x (np.ndarray/list): Data for x-axis. y1 (dict/np.ndarray/list): Data for y1 axis (left). If a dict, it will be interpreted as a {label: sequence}. y2 (dict/np.ndarray/list): Data for y2 axis (right). If a dict, it will be interpreted as a {label: sequence}. xlabel (str): If not None, this will be the label for the x-axis. y1label (str): If not None, this will be the label for the y1-axis. y2label (str): If not None, this will be the label for the y2-axis. width (float): Width of plot in inches. Defaults to 8in. height (float): Height of plot in inches. Defaults to width * golden ratio. dpi (int): Sets dot per inch for figure. Defaults to 300. Returns: matplotlib.pyplot """ import palettable.colorbrewer.diverging colors = palettable.colorbrewer.diverging.RdYlBu_4.mpl_colors c1 = colors[0] c2 = colors[-1] golden_ratio = (math.sqrt(5) - 1) / 2 if not height: height = int(width * golden_ratio) import matplotlib.pyplot as plt width = 12 labelsize = int(width * 3) ticksize = int(width * 2.5) styles = ["-", "--", "-.", "."] fig, ax1 = plt.subplots() fig.set_size_inches((width, height)) if dpi: fig.set_dpi(dpi) if isinstance(y1, dict): for i, (k, v) in enumerate(y1.items()): ax1.plot(x, v, c=c1, marker='s', ls=styles[i % len(styles)], label=k) ax1.legend(fontsize=labelsize) else: ax1.plot(x, y1, c=c1, marker='s', ls='-') if xlabel: ax1.set_xlabel(xlabel, fontsize=labelsize) if y1label: # Make the y-axis label, ticks and tick labels match the line color. ax1.set_ylabel(y1label, color=c1, fontsize=labelsize) ax1.tick_params('x', labelsize=ticksize) ax1.tick_params('y', colors=c1, labelsize=ticksize) ax2 = ax1.twinx() if isinstance(y2, dict): for i, (k, v) in enumerate(y2.items()): ax2.plot(x, v, c=c2, marker='o', ls=styles[i % len(styles)], label=k) ax2.legend(fontsize=labelsize) else: ax2.plot(x, y2, c=c2, marker='o', ls='-') if y2label: # Make the y-axis label, ticks and tick labels match the line color. ax2.set_ylabel(y2label, color=c2, fontsize=labelsize) ax2.tick_params('y', colors=c2, labelsize=ticksize) return plt
[docs]def pretty_polyfit_plot(x, y, deg=1, xlabel=None, ylabel=None, **kwargs): """ Convenience method to plot data with trend lines based on polynomial fit. Args: x: Sequence of x data. y: Sequence of y data. deg (int): Degree of polynomial. Defaults to 1. xlabel (str): Label for x-axis. ylabel (str): Label for y-axis. \\*\\*kwargs: Keyword args passed to pretty_plot. Returns: matplotlib.pyplot object. """ plt = pretty_plot(**kwargs) pp = np.polyfit(x, y, deg) xp = np.linspace(min(x), max(x), 200) plt.plot(xp, np.polyval(pp, xp), 'k--', x, y, 'o') if xlabel: plt.xlabel(xlabel) if ylabel: plt.ylabel(ylabel) return plt
[docs]def periodic_table_heatmap(elemental_data, cbar_label="", show_plot=False, cmap="YlOrRd", blank_color="grey", show_value=True): """ A static method that generates a heat map overlapped on a periodic table. Args: elemental_data (dict): A dictionary with the element as a key and a value assigned to it, e.g. surface energy and frequency, etc. Elements missing in the elemental_data will be grey by default in the final table elemental_data={"Fe": 4.2, "O": 5.0}. cbar_label (string): Label of the colorbar. Default is "". figure_name (string): Name of the plot (absolute path) being saved if not None. show_plot (bool): Whether to show the heatmap. Default is False. cmap (string): Color scheme of the heatmap. Default is 'coolwarm'. blank_color (string): Color assigned for the missing elements in elemental_data. Default is "grey". """ # Convert elemental data in the form of numpy array for plotting. max_val = max(elemental_data.values()) min_val = min(elemental_data.values()) value_table = np.empty((9, 18)) * np.nan blank_value = min_val - 0.01 for el in Element: value = elemental_data.get(el.symbol, blank_value) value_table[el.row - 1, el.group - 1] = value # Initialize the plt object import matplotlib.pyplot as plt fig, ax = plt.subplots() plt.gcf().set_size_inches(12, 8) # We set nan type values to masked values (ie blank spaces) data_mask = np.ma.masked_invalid(value_table.tolist()) heatmap = ax.pcolor(data_mask, cmap=cmap, edgecolors='w', linewidths=1, vmin=min_val-0.001, vmax=max_val+0.001) cbar = fig.colorbar(heatmap) # Grey out missing elements in input data cbar.cmap.set_under(blank_color) cbar.set_label(cbar_label, rotation=270, labelpad=15) cbar.ax.tick_params(labelsize=14) # Refine and make the table look nice ax.axis('off') ax.invert_yaxis() # Label each block with corresponding element and value for i, row in enumerate(value_table): for j, el in enumerate(row): if not np.isnan(el): symbol = Element.from_row_and_group(i+1, j+1).symbol plt.text(j + 0.5, i + 0.25, symbol, horizontalalignment='center', verticalalignment='center', fontsize=14) if el != blank_value and show_value: plt.text(j + 0.5, i + 0.5, "%.2f" % (el), horizontalalignment='center', verticalalignment='center', fontsize=10) plt.tight_layout() if show_plot: plt.show() return plt
[docs]def get_ax_fig_plt(ax=None): """ Helper function used in plot functions supporting an optional Axes argument. If ax is None, we build the `matplotlib` figure and create the Axes else we return the current active figure. Returns: ax: :class:`Axes` object figure: matplotlib figure plt: matplotlib pyplot module. """ import matplotlib.pyplot as plt if ax is None: fig = plt.figure() ax = fig.add_subplot(1, 1, 1) else: fig = plt.gcf() return ax, fig, plt
[docs]def get_ax3d_fig_plt(ax=None): """ Helper function used in plot functions supporting an optional Axes3D argument. If ax is None, we build the `matplotlib` figure and create the Axes3D else we return the current active figure. Returns: ax: :class:`Axes` object figure: matplotlib figure plt: matplotlib pyplot module. """ import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import axes3d if ax is None: fig = plt.figure() ax = axes3d.Axes3D(fig) else: fig = plt.gcf() return ax, fig, plt
[docs]def get_axarray_fig_plt(ax_array, nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, subplot_kw=None, gridspec_kw=None, **fig_kw): """ Helper function used in plot functions that accept an optional array of Axes as argument. If ax_array is None, we build the `matplotlib` figure and create the array of Axes by calling plt.subplots else we return the current active figure. Returns: ax: Array of :class:`Axes` objects figure: matplotlib figure plt: matplotlib pyplot module. """ import matplotlib.pyplot as plt if ax_array is None: fig, ax_array = plt.subplots(nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey, squeeze=squeeze, subplot_kw=subplot_kw, gridspec_kw=gridspec_kw, **fig_kw) else: fig = plt.gcf() if squeeze: ax_array = np.array(ax_array).ravel() if len(ax_array) == 1: ax_array = ax_array[1] return ax_array, fig, plt
[docs]def add_fig_kwargs(func): """ Decorator that adds keyword arguments for functions returning matplotlib figures. The function should return either a matplotlib figure or None to signal some sort of error/unexpected event. See doc string below for the list of supported options. """ from functools import wraps @wraps(func) def wrapper(*args, **kwargs): # pop the kwds used by the decorator. title = kwargs.pop("title", None) size_kwargs = kwargs.pop("size_kwargs", None) show = kwargs.pop("show", True) savefig = kwargs.pop("savefig", None) tight_layout = kwargs.pop("tight_layout", False) # Call func and return immediately if None is returned. fig = func(*args, **kwargs) if fig is None: return fig # Operate on matplotlib figure. if title is not None: fig.suptitle(title) if size_kwargs is not None: fig.set_size_inches(size_kwargs.pop("w"), size_kwargs.pop("h"), **size_kwargs) if savefig: fig.savefig(savefig) if tight_layout: fig.tight_layout() if show: import matplotlib.pyplot as plt plt.show() return fig # Add docstring to the decorated method. s = "\n" + """\ keyword arguments controlling the display of the figure: ================ ==================================================== kwargs Meaning ================ ==================================================== title Title of the plot (Default: None). show True to show the figure (default: True). savefig 'abc.png' or 'abc.eps' to save the figure to a file. size_kwargs Dictionary with options passed to fig.set_size_inches example: size_kwargs=dict(w=3, h=4) tight_layout True if to call fig.tight_layout (default: False) ================ ====================================================""" if wrapper.__doc__ is not None: # Add s at the end of the docstring. wrapper.__doc__ += "\n" + s else: # Use s wrapper.__doc__ = s return wrapper