API
module correlation
Plot distributions of correlation matrix eigenvalues.
Global Variables
- TYPE_CHECKING
function marchenko_pastur_pdf
marchenko_pastur_pdf(x: 'float', gamma: 'float', sigma: 'float' = 1) → float
Generate Marchenko-Pastur probability distribution which describes the density of singular values of large rectangular random matrices.
See https://wikipedia.org/wiki/Marchenko-Pastur_distribution.
By comparing the eigenvalue distribution of a correlation matrix to this PDF, one can gauge the significance of correlations.
Args:
x (float): Position at which to compute probability density.
gamma (float): Also referred to as lambda. The distribution's main parameter that measures how well sampled the data is.
sigma (float, optional): Standard deviation of random variables assumed to be independent identically distributed. Defaults to 1 as appropriate for correlation matrices.
Returns:
float: Marchenko-Pastur density for given gamma at x
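As a rough illustration of the density described above, here is a minimal pure-Python sketch using the parametrisation from the linked Wikipedia article, where the ratio enters as lam and the support runs from sigma^2 (1 - sqrt(lam))^2 to sigma^2 (1 + sqrt(lam))^2. The name mp_density_sketch and the argument lam are made up for this example. How the library's gamma maps onto lam is not stated in this reference, so that mapping is an assumption.

```python
import math


def mp_density_sketch(x: float, lam: float, sigma: float = 1.0) -> float:
    """Marchenko-Pastur density in the Wikipedia parametrisation (0 < lam <= 1)."""
    lower = sigma**2 * (1 - math.sqrt(lam)) ** 2  # lower edge of the support
    upper = sigma**2 * (1 + math.sqrt(lam)) ** 2  # upper edge of the support
    if x <= lower or x >= upper:
        return 0.0  # no density outside the support
    root = math.sqrt((upper - x) * (x - lower))
    return root / (2 * math.pi * sigma**2 * lam * x)


# Midpoint-rule check that the density integrates to roughly 1 for lam = 0.5
lam = 0.5
lo, hi = (1 - math.sqrt(lam)) ** 2, (1 + math.sqrt(lam)) ** 2
n_steps = 20_000
step = (hi - lo) / n_steps
total = sum(mp_density_sketch(lo + (i + 0.5) * step, lam) for i in range(n_steps)) * step
print(round(total, 3))
```

The numerical integral is only a sanity check on the sketch; it says nothing about the library's own implementation.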
function marchenko_pastur
marchenko_pastur(
matrix: 'ArrayLike',
gamma: 'float',
sigma: 'float' = 1,
filter_high_evals: 'bool' = False,
ax: 'Axes | None' = None
) → Axes
Plot the eigenvalue distribution of a symmetric matrix (usually a correlation matrix) against the Marchenko-Pastur distribution.
The probability of a random matrix having eigenvalues >= (1 + sqrt(gamma))^2 in the absence of any signal is vanishingly small. Thus, if eigenvalues larger than that appear, they correspond to statistically significant signals.
Args:
matrix (ArrayLike): 2d array
gamma (float): The Marchenko-Pastur ratio of random variables to observation count. E.g. for N=1000 variables and p=500 observations of each, gamma = p/N = 1/2.
sigma (float, optional): Standard deviation of random variables. Defaults to 1.
filter_high_evals (bool, optional): Whether to filter out eigenvalues larger than theoretical random maximum. Useful for focusing the plot on the area of the MP PDF. Defaults to False.
ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
Returns:
plt.Axes: matplotlib Axes object
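The significance criterion quoted above can be sketched without any plotting: eigenvalues at or above (1 + sqrt(gamma))^2 are flagged as signal, the rest as compatible with noise. This only illustrates that threshold for sigma = 1. The helper split_by_mp_edge and the sample eigenvalues are hypothetical and not part of the module.

```python
import math


def split_by_mp_edge(eigenvalues, gamma):
    """Split eigenvalues at the upper edge (1 + sqrt(gamma))^2 quoted above."""
    upper_edge = (1 + math.sqrt(gamma)) ** 2
    noise = [ev for ev in eigenvalues if ev < upper_edge]
    signal = [ev for ev in eigenvalues if ev >= upper_edge]
    return upper_edge, noise, signal


# Made-up eigenvalues purely for illustration
edge, noise, signal = split_by_mp_edge([0.2, 0.9, 1.7, 2.5, 6.3], gamma=0.5)
print(f"edge={edge:.3f} noise={noise} signal={signal}")
```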
module cumulative
Plot the cumulative distribution of residuals and absolute errors.
Global Variables
- TYPE_CHECKING
function cumulative_residual
cumulative_residual(
res: 'ArrayLike',
ax: 'Axes | None' = None,
**kwargs: 'Any'
) → Axes
Plot the empirical cumulative distribution for the residuals (y - mu).
Args:
res (array): Residuals between y_true and y_pred, i.e. targets - model predictions.
ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
**kwargs: Additional keyword arguments passed to ax.fill_between().
Returns:
plt.Axes: matplotlib Axes object
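For readers unfamiliar with empirical cumulative distributions, the following sketch shows the quantity being plotted: sorted residuals on the x-axis and the fraction of samples at or below each value on the y-axis. The function ecdf and the sample data are illustrative assumptions, and the library may compute or normalize the curve differently.

```python
def ecdf(values):
    """Return sorted values and the cumulative fraction of samples <= each value."""
    xs = sorted(values)
    n = len(xs)
    fractions = [(i + 1) / n for i in range(n)]
    return xs, fractions


y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 1.8, 3.0, 3.0]
residuals = [t - p for t, p in zip(y_true, y_pred)]  # targets - model predictions
xs, fractions = ecdf(residuals)
```

Passing absolute values of the residuals to the same helper gives the curve that cumulative_error describes.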
function cumulative_error
cumulative_error(
abs_err: 'ArrayLike',
ax: 'Axes | None' = None,
**kwargs: 'Any'
) → Axes
Plot the empirical cumulative distribution of the absolute errors abs(y_true - y_pred).
Args:
abs_err (array): Absolute error between y_true and y_pred, i.e. abs(targets - model predictions).
ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
**kwargs: Additional keyword arguments passed to ax.plot().
Returns:
plt.Axes: matplotlib Axes object
module histograms
Histograms and bar charts.
Global Variables
- TYPE_CHECKING
- plotly_key
function true_pred_hist
true_pred_hist(
y_true: 'ArrayLike | str',
y_pred: 'ArrayLike | str',
y_std: 'ArrayLike | str',
df: 'DataFrame | None' = None,
ax: 'Axes | None' = None,
cmap: 'str' = 'hot',
truth_color: 'str' = 'blue',
true_label: 'str' = '$y_\mathrm{true}$',
pred_label: 'str' = '$y_\mathrm{pred}$',
**kwargs: 'Any'
) → Axes
Plot a histogram of model predictions with bars colored by the mean uncertainty of predictions in that bin. Overlaid by a more transparent histogram of ground truth values.
Args:
y_true (array | str): Ground truth targets as array or df column name.
y_pred (array | str): Model predictions as array or df column name.
y_std (array | str): Model uncertainty as array or df column name.
df (DataFrame, optional): DataFrame containing y_true, y_pred, and y_std.
ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
cmap (str, optional): String identifier of a plt colormap. Defaults to 'hot'.
truth_color (str, optional): Face color to use for y_true bars. Defaults to 'blue'.
true_label (str, optional): Label for y_true bars. Defaults to '$y_\mathrm{true}$'.
pred_label (str, optional): Label for y_pred bars. Defaults to '$y_\mathrm{pred}$'.
**kwargs: Additional keyword arguments to pass to ax.hist().
Returns:
plt.Axes: matplotlib Axes object
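The bar coloring described above depends on the mean uncertainty of the predictions falling into each histogram bin. A minimal sketch of that aggregation, assuming equal-width bins over the range of y_pred, might look like this. The helper mean_std_per_bin and its binning rule are assumptions, not the library's code.

```python
def mean_std_per_bin(y_pred, y_std, n_bins):
    """Bin predictions into equal-width bins and average y_std within each bin."""
    lo, hi = min(y_pred), max(y_pred)
    width = (hi - lo) / n_bins or 1.0  # avoid zero width if all predictions are equal
    sums = [0.0] * n_bins
    counts = [0] * n_bins
    for pred, std in zip(y_pred, y_std):
        idx = min(int((pred - lo) / width), n_bins - 1)  # maximum goes in the last bin
        sums[idx] += std
        counts[idx] += 1
    means = [s / c if c else None for s, c in zip(sums, counts)]  # None for empty bins
    return counts, means


counts, means = mean_std_per_bin(
    y_pred=[0.0, 0.1, 0.5, 0.9, 1.0], y_std=[0.2, 0.4, 0.1, 0.3, 0.5], n_bins=2
)
```

The per-bin means would then be mapped through the colormap named by cmap to obtain one color per bar.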
function spacegroup_hist
spacegroup_hist(
data: 'Sequence[int | str | Structure] | Series',
show_counts: 'bool' = True,
xticks: "Literal['all', 'crys_sys_edges'] | int" = 20,
show_empty_bins: 'bool' = False,
ax: 'Axes | None' = None,
backend: 'Backend' = 'plotly',
text_kwargs: 'dict[str, Any] | None' = None,
**kwargs: 'Any'
) → Axes | Figure
Plot a histogram of spacegroups shaded by crystal system.
Args:
data (list[int | str | Structure] | pd.Series): Space group strings or numbers (from 1 - 230) or pymatgen structures.
show_counts (bool, optional): Whether to count the number of items in each crystal system. Defaults to True.
xticks ('all' | 'crys_sys_edges' | int, optional): Where to add x-ticks. An integer will add ticks below that number of tallest bars. 'all' will show ticks below all bars, 'crys_sys_edges' only at the edge from one crystal system to another. Defaults to 20.
show_empty_bins (bool, optional): Whether to include a 0-height bar for space groups missing from the data. Currently only implemented for numbers, not symbols. Defaults to False.
ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
backend ("matplotlib" | "plotly", optional): Which backend to use for plotting. Defaults to "plotly".
text_kwargs (dict, optional): Keyword arguments passed to matplotlib.Axes.text(). Defaults to None. Has no effect if backend is "plotly".
**kwargs: Keywords passed to pd.Series.plot.bar() or plotly.express.bar().
Returns:
plt.Axes | go.Figure: matplotlib Axes or plotly Figure depending on backend.
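Shading by crystal system relies on the standard assignment of the 230 space group numbers to the seven crystal systems. The sketch below reproduces that mapping and counts entries per system. The names CRYSTAL_SYSTEM_RANGES and crystal_system are introduced here for illustration and are not claimed to exist in the library.

```python
from collections import Counter

# Inclusive space group number ranges of the seven crystal systems
CRYSTAL_SYSTEM_RANGES = {
    "triclinic": (1, 2),
    "monoclinic": (3, 15),
    "orthorhombic": (16, 74),
    "tetragonal": (75, 142),
    "trigonal": (143, 167),
    "hexagonal": (168, 194),
    "cubic": (195, 230),
}


def crystal_system(spg_num: int) -> str:
    """Look up the crystal system of an international space group number."""
    for name, (first, last) in CRYSTAL_SYSTEM_RANGES.items():
        if first <= spg_num <= last:
            return name
    raise ValueError(f"invalid space group number {spg_num}")


# Made-up sample of space group numbers
spg_numbers = [225, 225, 62, 194, 2, 14, 14, 14]
counts = Counter(crystal_system(num) for num in spg_numbers)
```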
function elements_hist
elements_hist(
formulas: 'ElemValues',
count_mode: 'CountMode' = 'composition',
log: 'bool' = False,
keep_top: 'int | None' = None,
ax: 'Axes | None' = None,
bar_values: "Literal['percent', 'count'] | None" = 'percent',
h_offset: 'int' = 0,
v_offset: 'int' = 10,
rotation: 'int' = 45,
**kwargs: 'Any'
) → Axes
Plot a histogram of elements (e.g. to show occurrence in a dataset).
Adapted from https://github.com/kaaiian/ML_figures (https://git.io/JmbaI).
Args:
formulas (list[str]): Compositional strings, e.g. ["Fe2O3", "Bi2Te3"].
count_mode ('composition' | 'fractional_composition' | 'reduced_composition'): Reduce or normalize compositions before counting. See count_elements() for details. Only used when formulas is a list of composition strings/objects.
log (bool, optional): Whether y-axis is log or linear. Defaults to False.
keep_top (int | None): Display only the top n elements by prevalence.
ax (Axes): matplotlib Axes on which to plot. Defaults to None.
bar_values ('percent' | 'count' | None): 'percent' (default) annotates bars with the percentage each element makes up in the total element count. 'count' displays the count itself. None removes bar labels.
h_offset (int): Horizontal offset for bar height labels. Defaults to 0.
v_offset (int): Vertical offset for bar height labels. Defaults to 10.
rotation (int): Bar label angle. Defaults to 45.
**kwargs: Keyword arguments passed to pandas.Series.plot.bar().
Returns:
plt.Axes: matplotlib Axes object
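To make the counting and the 'percent' bar labels concrete, here is a small regex-based sketch that sums raw element amounts over bracket-free formulas. It is a simplification: count_elements_sketch is a hypothetical helper, and whether it matches any particular count_mode of the library is an assumption.

```python
import re
from collections import Counter


def count_elements_sketch(formulas):
    """Sum element amounts over simple formulas like "Fe2O3" (no brackets)."""
    totals = Counter()
    for formula in formulas:
        # an element symbol is one uppercase letter plus an optional lowercase letter
        for symbol, amount in re.findall(r"([A-Z][a-z]?)(\d*)", formula):
            totals[symbol] += int(amount) if amount else 1
    return totals


totals = count_elements_sketch(["Fe2O3", "Bi2Te3", "FeO"])
n_total = sum(totals.values())
# share of each element in the total element count, as used for 'percent' labels
percent = {elem: 100 * amt / n_total for elem, amt in totals.items()}
```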
module io
I/O utilities for saving figures and dataframes to various image formats.
Global Variables
- DEFAULT_BUFFER_SIZE
- SEEK_SET
- SEEK_CUR
- SEEK_END
- TYPE_CHECKING
- ROOT
- DEFAULT_DF_STYLES
- ALLOW_TABLE_SCROLL
- HIDE_SCROLL_BAR
function save_fig
save_fig(
fig: 'Figure | Figure | Axes',
path: 'str',
plotly_config: 'dict[str, Any] | None' = None,
env_disable: 'Sequence[str]' = ('CI',),
pdf_sleep: 'float' = 0.6,
style: 'str' = '',
prec: 'int | None' = None,
template: 'str | None' = None,
**kwargs: 'Any'
) → None
Write a plotly or matplotlib figure to disk (as HTML/PDF/SVG/…).
If the file has a .svelte extension, insert {...$$props} into the figure's top-level div so it can be later styled and customized from Svelte code.
Args:
fig (go.Figure | plt.Figure | plt.Axes): Plotly or matplotlib Figure or matplotlib Axes object.
path (str): Path to image file that will be created.
plotly_config (dict, optional): Configuration options for fig.write_html(). Defaults to dict(showTips=False, responsive=True, modeBarButtonsToRemove=["lasso2d", "select2d", "autoScale2d", "toImage"]). See https://plotly.com/python/configuration-options.
env_disable (list[str], optional): Do nothing if any of these environment variables are set. Defaults to ("CI",).
pdf_sleep (float, optional): Minimum time in seconds to wait before writing a plotly figure to PDF file. Workaround for this plotly issue: https://github.com/plotly/plotly.py/issues/3469. Defaults to 0.6. Has no effect on matplotlib figures.
style (str, optional): CSS style string to be inserted into the HTML file. Defaults to "". Only used if path ends with .svelte or .html.
prec (int, optional): Number of significant figures to keep for any float in the exported file. Defaults to None (no rounding). Sensible values are usually 4, 5, 6.
template (str, optional): Temporary plotly template to apply to the figure before saving. Will be reset to the original after. Defaults to "pymatviz_white" if path ends with .pdf or .pdfa, else None. Set to None to disable. Only used if fig is a plotly figure.
**kwargs: Keyword arguments passed to fig.write_html().
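The env_disable behavior amounts to a guard that returns early when any listed environment variable is present. A hedged sketch of such a check follows. The function saving_disabled is hypothetical, and it treats "set" as "present in the environment", which may differ from the library's exact rule.

```python
import os


def saving_disabled(env_disable=("CI",), environ=None):
    """Return True if any of the named environment variables is present."""
    environ = os.environ if environ is None else environ
    return any(var in environ for var in env_disable)


# Explicit mappings stand in for os.environ to keep the example deterministic
print(saving_disabled(("CI",), environ={"CI": "true"}))
print(saving_disabled(("CI",), environ={"HOME": "/home/user"}))
```

Passing an empty tuple for env_disable would make such a guard always allow saving.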
function save_and_compress_svg
save_and_compress_svg(fig: 'Figure | Figure | Axes', filename: 'str') → None
Save Plotly figure as SVG and HTML to assets/ folder. Compresses SVG file with svgo CLI if available in PATH.
Args:
fig (Figure): Plotly or matplotlib Figure/Axes instance.
filename (str): Name of SVG file (w/o extension).
Raises:
ValueError: If fig is None and plt.gcf() is empty.
function df_to_pdf
df_to_pdf(
styler: 'Styler',
file_path: 'str | Path',
crop: 'bool' = True,
size: 'str | None' = None,
style: 'str' = '',
styler_css: 'bool | dict[str, str]' = True,
**kwargs: 'Any'
) → None
Export a pandas Styler to PDF with WeasyPrint.
Args:
styler (Styler): Styler object to export.
file_path (str): Path to save the PDF to. Requires WeasyPrint.
crop (bool): Whether to crop the PDF margins. Requires pdfCropMargins. Defaults to True. Be careful to set size correctly (not much too large as is the default) if you set crop=False.
size (str): Page size. Defaults to "4cm n_cols x 2cm n_rows" (width x height). See https://developer.mozilla.org/@page for 'landscape' and other options.
style (str): CSS style string to be inserted into the HTML file. Defaults to "".
styler_css (bool | dict[str, str]): Whether to apply some sensible default CSS to the pandas Styler. Defaults to True. If dict, keys are selectors and values CSS strings. Example: {"td, th": "border: none; padding: 4px;"}
**kwargs: Keyword arguments passed to Styler.to_html().
function normalize_and_crop_pdf
normalize_and_crop_pdf(
file_path: 'str | Path',
on_gs_not_found: "Literal['ignore', 'warn', 'error']" = 'warn'
) → None
Normalize a PDF using Ghostscript and then crop it. Without gs normalization, pdfCropMargins sometimes corrupts the PDF.
Args:
file_path (str | Path): Path to the PDF file.
on_gs_not_found ('ignore' | 'warn' | 'error', optional): What to do if Ghostscript is not found in PATH. Defaults to 'warn'.
function df_to_html_table
df_to_html_table(
styler: 'Styler',
file_path: 'str | Path | None' = None,
inline_props: 'str | None' = '',
script: 'str | None' = '',
styles: 'str | None' = 'table { overflow: scroll; max-width: 100%; display: block; }\ntable {\n scrollbar-width: none; /* Firefox */\n}\ntable::-webkit-scrollbar {\n display: none; /* Safari and Chrome */\n}',
styler_css: 'bool | dict[str, str]' = True,
sortable: 'bool' = True,
post_process: 'Callable[[str], str] | None' = None,
**kwargs: 'Any'
) → str
Convert a pandas Styler to a svelte table.
Args:
styler (Styler): Styler object to export.
file_path (str): Path to the file to write the svelte table to.
inline_props (str): Inline props to pass to the table element. Example: "class='table' style='width: 100%'". Defaults to "".
script (str): JavaScript string to insert above the table. Will replace the opening HTML table tag to allow passing props to it. The default script uses ...props to enable Svelte props forwarding to the table element. See source code to inspect default script.
styles (str): CSS rules to insert at the bottom of the style tag. Defaults to TABLE_SCROLL_CSS.
styler_css (bool | dict[str, str]): Whether to apply some sensible default CSS to the pandas Styler. Defaults to True. If dict, keys are CSS selectors and values CSS strings. Example: {"td, th": "border: none; padding: 4px 6px;"}
sortable (bool): Whether to enable sorting the table by clicking on column headers. Defaults to True. Requires npm install svelte-zoo.
post_process (Callable[[str], str]): Function to post-process the HTML string before writing it to file. Defaults to None.
**kwargs: Keyword arguments passed to Styler.to_html().
Returns:
str: pandas Styler as HTML.
class BufferedIOBase
Base class for buffered IO objects.
The main difference with RawIOBase is that the read() method supports omitting the size argument, and does not have a default implementation that defers to readinto().
In addition, read(), readinto() and write() may raise BlockingIOError if the underlying raw stream is in non-blocking mode and not ready; unlike their raw counterparts, they will never return None.
A typical implementation should not inherit from a RawIOBase implementation, but wrap one.
class IOBase
The abstract base class for all I/O classes.
This class provides dummy implementations for many methods that derived classes can override selectively; the default implementations represent a file that cannot be read, written or seeked.
Even though IOBase does not declare read, readinto, or write because their signatures will vary, implementations and clients should consider those methods part of the interface. Also, implementations may raise UnsupportedOperation when operations they do not support are called.
The basic type used for binary data read from or written to a file is bytes. Other bytes-like objects are accepted as method arguments too. In some cases (such as readinto), a writable object is required. Text I/O classes work with str data.
Note that calling any method (except additional calls to close(), which are ignored) on a closed stream should raise a ValueError.
IOBase (and its subclasses) support the iterator protocol, meaning that an IOBase object can be iterated over yielding the lines in a stream.
IOBase also supports the with statement. In this example, fp is closed after the suite of the with statement is complete:
with open('spam.txt', 'w') as fp: fp.write('Spam and eggs!')
class RawIOBase
Base class for raw binary I/O.
class TextIOBase
Base class for text I/O.
This class provides a character and line based interface to stream I/O. There is no readinto method because Python’s character strings are immutable.
class UnsupportedOperation
class TqdmDownload
Progress bar for urllib.request.urlretrieve file download.
Adapted from official TqdmUpTo example. See https://github.com/tqdm/tqdm/blob/4c956c20b83be4312460fc0c4812eeb3fef5e7df/README.rst#hooks-and-callbacks
method __init__
__init__(*args: 'Any', **kwargs: 'Any') → None
Sets default values appropriate for file downloads for unit, unit_scale, unit_divisor, miniters, desc.
property format_dict
Public API for read-only member access.
method update_to
update_to(
n_blocks: 'int' = 1,
block_size: 'int' = 1,
total_size: 'int | None' = None
) → bool | None
Update hook for urlretrieve.
Args:
n_blocks (int, optional): Number of blocks transferred so far. Default = 1.
block_size (int, optional): Size of each block (in tqdm units). Default = 1.
total_size (int, optional): Total size (in tqdm units). If None, remains unchanged. Defaults to None.
Returns:
bool | None: True if tqdm.display() was triggered.
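The hook receives absolute progress (blocks so far, block size, total size), which is the argument order urllib.request.urlretrieve uses for its reporthook. The stand-in class below mimics that bookkeeping without tqdm. DownloadProgressSketch is a made-up name and is not how TqdmDownload is implemented.

```python
class DownloadProgressSketch:
    """Tiny stand-in for a progress bar driven by a urlretrieve-style hook."""

    def __init__(self):
        self.n = 0  # units transferred so far
        self.total = None  # total size, unknown until reported

    def update_to(self, n_blocks=1, block_size=1, total_size=None):
        if total_size is not None:
            self.total = total_size  # None leaves the known total unchanged
        self.n = n_blocks * block_size  # absolute progress, not an increment
        return self.n


bar = DownloadProgressSketch()
# Simulate three calls as a download of 4096 units proceeds in 1024-unit blocks
for block_idx in range(1, 4):
    bar.update_to(n_blocks=block_idx, block_size=1024, total_size=4096)
```

With the real class, the bound method would be passed as the reporthook argument of urllib.request.urlretrieve.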
module parity
Parity, residual and density plots.
Global Variables
- TYPE_CHECKING
function hist_density
hist_density(
x: 'ArrayLike | str',
y: 'ArrayLike | str',
df: 'DataFrame | None' = None,
sort: 'bool' = True,
bins: 'int' = 100,
method: 'str' = 'nearest'
) → tuple[ArrayLike, ArrayLike, ArrayLike]
Return an approximate density of 2d points.
Args:
x (array | str): x-values or dataframe column name.
y (array | str): y-values or dataframe column name.
df (pd.DataFrame, optional): DataFrame with x and y columns. Defaults to None.
sort (bool, optional): Whether to sort points by density so that densest points are plotted last. Defaults to True.
bins (int, optional): Number of bins (histogram resolution). Defaults to 100.
method (str, optional): Interpolation method. Defaults to "nearest". See scipy.interpolate.interpn() for options.
Returns:
tuple[np.array, np.array, np.array]: x and y values (sorted by density) and density itself
function density_scatter
density_scatter(
x: 'ArrayLike | str',
y: 'ArrayLike | str',
df: 'DataFrame | None' = None,
ax: 'Axes | None' = None,
log_density: 'bool' = True,
hist_density_kwargs: 'dict[str, Any] | None' = None,
color_bar: 'bool | dict[str, Any]' = True,
xlabel: 'str | None' = None,
ylabel: 'str | None' = None,
identity_line: 'bool | dict[str, Any]' = True,
best_fit_line: 'bool | dict[str, Any]' = True,
stats: 'bool | dict[str, Any]' = True,
**kwargs: 'Any'
) → Axes
Scatter plot colored (and optionally sorted) by density.
Args:
x (array | str): x-values or dataframe column name.
y (array | str): y-values or dataframe column name.
df (pd.DataFrame, optional): DataFrame with x and y columns. Defaults to None.
ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
sort (bool, optional): Whether to sort the data. Defaults to True. Not part of the signature above; pass it via hist_density_kwargs.
log_density (bool, optional): Whether to log the density color scale. Defaults to True.
hist_density_kwargs (dict, optional): Passed to hist_density(). Use to change sort (by density, default True), bins (default 100), or method (for interpolation, default "nearest").
color_bar (bool | dict, optional): Whether to add a color bar. Defaults to True. If dict, unpacked into ax.figure.colorbar(). E.g. dict(label="Density").
xlabel (str, optional): x-axis label. Defaults to "Actual".
ylabel (str, optional): y-axis label. Defaults to "Predicted".
identity_line (bool | dict[str, Any], optional): Whether to add a parity line (y = x). Defaults to True. Pass a dict to customize line properties.
best_fit_line (bool | dict[str, Any], optional): Whether to add a best-fit line. Defaults to True. Pass a dict to customize line properties.
stats (bool | dict[str, Any], optional): Whether to display a text box with MAE and R^2. Defaults to True. Can be dict to pass kwargs to annotate_metrics(). E.g. stats=dict(loc="upper left", prefix="Title", prop=dict(fontsize=16)).
**kwargs: Passed to ax.scatter(). Defaults to dict(s=6) to control marker size. Other common keys are cmap, vmin, vmax, alpha, edgecolors, linewidths.
Returns:
plt.Axes: matplotlib Axes object
function scatter_with_err_bar
scatter_with_err_bar(
x: 'ArrayLike | str',
y: 'ArrayLike | str',
df: 'DataFrame | None' = None,
xerr: 'ArrayLike | None' = None,
yerr: 'ArrayLike | None' = None,
ax: 'Axes | None' = None,
identity_line: 'bool | dict[str, Any]' = True,
best_fit_line: 'bool | dict[str, Any]' = True,
xlabel: 'str' = 'Actual',
ylabel: 'str' = 'Predicted',
title: 'str | None' = None,
**kwargs: 'Any'
) → Axes
Scatter plot with optional x- and/or y-error bars. Useful when passing model uncertainties as yerr=y_std for checking if uncertainty correlates with error, i.e. if points farther from the parity line have larger uncertainty.
Args:
x (array | str): x-values or dataframe column name.
y (array | str): y-values or dataframe column name.
df (pd.DataFrame, optional): DataFrame with x and y columns. Defaults to None.
xerr (array, optional): Horizontal error bars. Defaults to None.
yerr (array, optional): Vertical error bars. Defaults to None.
ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
identity_line (bool | dict[str, Any], optional): Whether to add a parity line (y = x). Defaults to True. Pass a dict to customize line properties.
best_fit_line (bool | dict[str, Any], optional): Whether to add a best-fit line. Defaults to True. Pass a dict to customize line properties.
xlabel (str, optional): x-axis label. Defaults to "Actual".
ylabel (str, optional): y-axis label. Defaults to "Predicted".
title (str, optional): Plot title. Defaults to None.
**kwargs: Additional keyword arguments to pass to ax.errorbar().
Returns:
plt.Axes: matplotlib Axes object
function density_hexbin
density_hexbin(
x: 'ArrayLike | str',
y: 'ArrayLike | str',
df: 'DataFrame | None' = None,
ax: 'Axes | None' = None,
weights: 'ArrayLike | None' = None,
identity_line: 'bool | dict[str, Any]' = True,
best_fit_line: 'bool | dict[str, Any]' = True,
xlabel: 'str' = 'Actual',
ylabel: 'str' = 'Predicted',
cbar_label: 'str | None' = 'Density',
cbar_coords: 'tuple[float, float, float, float]' = (0.95, 0.03, 0.03, 0.7),
**kwargs: 'Any'
) → Axes
Hexagonal-grid scatter plot colored by point density or by density in third dimension passed as weights.
Args:
x (array): x-values or dataframe column name.
y (array): y-values or dataframe column name.
df (pd.DataFrame, optional): DataFrame with x and y columns. Defaults to None.
ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
weights (array, optional): If given, these values are accumulated in the bins. Otherwise, every point has value 1. Must be of the same length as x and y. Defaults to None.
identity_line (bool | dict[str, Any], optional): Whether to add a parity line (y = x). Defaults to True. Pass a dict to customize line properties.
best_fit_line (bool | dict[str, Any], optional): Whether to add a best-fit line. Defaults to True. Pass a dict to customize line properties.
xlabel (str, optional): x-axis label. Defaults to "Actual".
ylabel (str, optional): y-axis label. Defaults to "Predicted".
cbar_label (str, optional): Color bar label. Defaults to "Density".
cbar_coords (tuple[float, float, float, float], optional): Color bar position and size: [x, y, width, height] anchored at lower left corner. Defaults to (0.95, 0.03, 0.03, 0.7).
**kwargs: Additional keyword arguments to pass to ax.hexbin().
Returns:
plt.Axes: matplotlib Axes object
function density_scatter_with_hist
density_scatter_with_hist(
x: 'ArrayLike | str',
y: 'ArrayLike | str',
df: 'DataFrame | None' = None,
cell: 'GridSpec | None' = None,
bins: 'int' = 100,
ax: 'Axes | None' = None,
**kwargs: 'Any'
) → Axes
Scatter plot colored (and optionally sorted) by density with histograms along each dimension.
function density_hexbin_with_hist
density_hexbin_with_hist(
x: 'ArrayLike | str',
y: 'ArrayLike | str',
df: 'DataFrame | None' = None,
cell: 'GridSpec | None' = None,
bins: 'int' = 100,
**kwargs: 'Any'
) → Axes
Hexagonal-grid scatter plot colored by density or by a third dimension passed as color_by, with histograms along each dimension.
function residual_vs_actual
residual_vs_actual(
y_true: 'ArrayLike | str',
y_pred: 'ArrayLike | str',
df: 'DataFrame | None' = None,
ax: 'Axes | None' = None,
xlabel: 'str' = 'Actual value',
ylabel: 'str' = 'Residual ($y_\mathrm{true} - y_\mathrm{pred}$)',
**kwargs: 'Any'
) → Axes
Plot targets on the x-axis vs residuals (y_err = y_true - y_pred) on the y-axis.
Args:
y_true (array): Ground truth values.
y_pred (array): Model predictions.
df (pd.DataFrame, optional): DataFrame with y_true and y_pred columns. Defaults to None.
ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
xlabel (str, optional): x-axis label. Defaults to "Actual value".
ylabel (str, optional): y-axis label. Defaults to 'Residual ($y_\mathrm{true} - y_\mathrm{pred}$)'.
**kwargs: Additional keyword arguments passed to plt.plot().
Returns:
plt.Axes: matplotlib Axes object
module phonons
Plotting functions for pymatgen phonon band structures and density of states.
Global Variables
- TYPE_CHECKING
function pretty_sym_point
pretty_sym_point(symbol: 'str') → str
Convert a symbol to a pretty-printed version.
function get_band_xaxis_ticks
get_band_xaxis_ticks(
band_struct: 'PhononBands',
branches: 'Sequence[str] | set[str]' = ()
) → tuple[list[float], list[str]]
Get all ticks and labels for a band structure plot.
Args:
branches (Sequence[str]): Branches to plot. Defaults to empty tuple, meaning all branches are plotted.
Returns:
tuple[list[float], list[str]]: Ticks and labels for the x-axis of a band structure plot.
function plot_phonon_bands
plot_phonon_bands(
band_structs: 'PhononBands | dict[str, PhononBands]',
line_kwargs: 'dict[str, Any] | None' = None,
branches: 'Sequence[str]' = (),
branch_mode: 'BranchMode' = 'union',
shaded_ys: 'dict[tuple[YMin, YMax], dict[str, Any]] | bool | None' = None,
**kwargs: 'Any'
) → Figure
Plot single or multiple pymatgen band structures using Plotly, focusing on the minimum set of overlapping branches.
Warning: Only tested with phonon band structures so far but plan is to extend to electronic band structures.
Args:
band_structs (PhononBandStructureSymmLine | dict[str, PhononBandStructure]): Single BandStructureSymmLine or PhononBandStructureSymmLine object or a dict with labels mapped to multiple such objects.
line_kwargs (dict[str, Any]): Passed to Plotly's Figure.add_scatter method.
branches (Sequence[str]): Branches to plot. Defaults to empty tuple, meaning all branches are plotted.
branch_mode ("union" | "intersection"): Whether to plot union or intersection of branches in case of multiple band structures with non-overlapping branches. Defaults to "union".
shaded_ys (dict[tuple[float | str, float | str], dict]): Keys are y-ranges (min, max) tuple and values are kwargs for shaded regions created by fig.add_hrect(). Defaults to single entry (0, "y_min"): dict(fillcolor="gray", opacity=0.07). "y_min" and "y_max" will be replaced with the figure's y-axis range. dict(layer="below", row="all", col="all") is always passed to add_hrect but can be overridden by the user. Set to False to disable.
**kwargs: Passed to Plotly's Figure.add_scatter method.
Returns:
go.Figure: Plotly figure object.
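The difference between the two branch_mode values can be shown with plain sets. In this sketch each band structure is reduced to a list of branch names; select_branches, the labels "DFT" and "ML", and the branch names are all invented for illustration.

```python
def select_branches(branches_per_struct, branch_mode="union"):
    """Combine branch names of several band structures by union or intersection."""
    branch_sets = [set(branches) for branches in branches_per_struct.values()]
    if branch_mode == "union":
        return set.union(*branch_sets)
    if branch_mode == "intersection":
        return set.intersection(*branch_sets)
    raise ValueError(f"unknown branch_mode={branch_mode!r}")


# Two hypothetical band structures that share only some of their branches
branches_per_struct = {
    "DFT": ["GAMMA-X", "X-W", "W-L"],
    "ML": ["GAMMA-X", "X-W", "L-GAMMA"],
}
union = select_branches(branches_per_struct, "union")
common = select_branches(branches_per_struct, "intersection")
```

In this example "union" keeps four branches, while "intersection" keeps only the two present in both band structures.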
function plot_phonon_dos
plot_phonon_dos(
doses: 'PhononDos | dict[str, PhononDos]',
stack: 'bool' = False,
sigma: 'float' = 0,
units: "Literal['THz', 'eV', 'meV', 'Ha', 'cm-1']" = 'THz',
normalize: "Literal['max', 'sum', 'integral'] | None" = None,
last_peak_anno: 'str | None' = None,
**kwargs: 'Any'
) → Figure
Plot phonon DOS using Plotly.
Args:
doses (PhononDos | dict[str, PhononDos]): PhononDos or dict of multiple.
stack (bool): Whether to plot the DOS as a stacked area graph. Defaults to False.
sigma (float): Standard deviation for Gaussian smearing. Defaults to 0.
units (str): Units for the frequencies. Defaults to "THz".
legend (dict): Legend configuration.
normalize ('max' | 'sum' | 'integral' | None): How to normalize the DOS. Defaults to None (no normalization).
last_peak_anno (str): Annotation for last DOS peak with f-string placeholders for key (of dict containing multiple DOSes), last_peak frequency and units. Defaults to None, meaning last peak annotation is disabled. Set to "" to enable with a sensible default string.
**kwargs: Passed to Plotly's Figure.add_scatter method.
Returns:
go.Figure: Plotly figure object.
function convert_frequencies
convert_frequencies(
frequencies: 'ndarray',
unit: "Literal['THz', 'eV', 'meV', 'Ha', 'cm-1']" = 'THz'
) → ndarray
Convert frequencies from THz to specified units.
Args:
frequencies (np.ndarray): Frequencies in THz.
unit (str): Target units. One of 'THz', 'eV', 'meV', 'Ha', 'cm-1'.
Returns:
np.ndarray: Converted frequencies.
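For orientation, the unit conversion can be sketched with factors derived from the Planck constant (E = h f) and the speed of light (wavenumber = f / c). The table THZ_TO and the function convert_thz are illustrative names. The library may use slightly different constants or an existing units package.

```python
# Conversion factors for 1 THz, derived from CODATA values of h, c and the Hartree energy
THZ_TO = {
    "THz": 1.0,
    "eV": 4.135667696e-3,  # h in eV*s times 1e12 Hz
    "meV": 4.135667696,
    "Ha": 4.135667696e-3 / 27.211386245988,  # eV divided by eV per Hartree
    "cm-1": 1e12 / 2.99792458e10,  # Hz divided by c in cm/s
}


def convert_thz(frequencies, unit="THz"):
    """Convert a sequence of frequencies from THz to the requested unit."""
    if unit not in THZ_TO:
        raise ValueError(f"unknown unit {unit!r}, expected one of {list(THZ_TO)}")
    return [freq * THZ_TO[unit] for freq in frequencies]


wavenumbers = convert_thz([1.0, 10.0], unit="cm-1")
energies_mev = convert_thz([1.0, 10.0], unit="meV")
```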
function plot_phonon_bands_and_dos
plot_phonon_bands_and_dos(
band_structs: 'PhononBands | dict[str, PhononBands]',
doses: 'PhononDos | dict[str, PhononDos]',
bands_kwargs: 'dict[str, Any] | None' = None,
dos_kwargs: 'dict[str, Any] | None' = None,
subplot_kwargs: 'dict[str, Any] | None' = None,
all_line_kwargs: 'dict[str, Any] | None' = None,
per_line_kwargs: 'dict[str, dict[str, Any]] | None' = None,
**kwargs: 'Any'
) → Figure
Plot phonon DOS and band structure using Plotly.
Args:
doses (PhononDos | dict[str, PhononDos]): PhononDos or dict of multiple.
band_structs (PhononBandStructureSymmLine | dict[str, PhononBandStructure]): Single BandStructureSymmLine or PhononBandStructureSymmLine object or a dict with labels mapped to multiple such objects.
bands_kwargs (dict[str, Any]): Passed to Plotly's Figure.add_scatter method.
dos_kwargs (dict[str, Any]): Passed to Plotly's Figure.add_scatter method.
subplot_kwargs (dict[str, Any]): Passed to Plotly's make_subplots method. Defaults to dict(shared_yaxes=True, column_widths=(0.8, 0.2), horizontal_spacing=0.01).
all_line_kwargs (dict[str, Any]): Passed to trace.update for each trace in fig.data. Modify line appearance for all traces. Defaults to None.
per_line_kwargs (dict[str, dict[str, Any]]): Map of line labels to kwargs for trace.update. Modify line appearance for specific traces. Defaults to None.
**kwargs: Passed to Plotly's Figure.add_scatter method.
Returns:
go.Figure: Plotly figure object.
class PhononDBDoc
Dataclass for phonon DB docs.
method __init__
__init__(
structure: 'Structure',
phonon_bandstructure: 'PhononBands',
phonon_dos: 'PhononDos',
free_energies: 'list[float]',
internal_energies: 'list[float]',
heat_capacities: 'list[float]',
entropies: 'list[float]',
temps: 'list[float] | None' = None,
has_imaginary_modes: 'bool | None' = None,
primitive: 'Structure | None' = None,
supercell: 'list[list[int]] | None' = None,
nac_params: 'dict[str, Any] | None' = None,
thermal_displacement_data: 'dict[str, Any] | None' = None,
mp_id: 'str | None' = None,
formula: 'str | None' = None
) → None
module powerups
Powerups/enhancements such as parity lines, annotations and marginals for matplotlib and plotly figures.
Global Variables
- TYPE_CHECKING
- mpl_key
- plotly_key
function with_marginal_hist
with_marginal_hist(
xs: 'ArrayLike',
ys: 'ArrayLike',
cell: 'GridSpec | None' = None,
bins: 'int' = 100,
fig: 'Figure | Axes | None' = None
) → Axes
Call before creating a plot and use the returned ax_main for all subsequent plotting ops to create a grid of plots with the main plot in the lower left and narrow histograms along its x- and/or y-axes displayed above and near the right edge.
Args:
xs (array): Marginal histogram values along x-axis.
ys (array): Marginal histogram values along y-axis.
cell (GridSpec, optional): Cell of a plt GridSpec at which to add the grid of plots. Defaults to None.
bins (int, optional): Resolution/bin count of the histograms. Defaults to 100.
fig (Figure, optional): matplotlib Figure or Axes to add the marginal histograms to. Defaults to None.
Returns:
plt.Axes: The matplotlib Axes to be used for the main plot.
function annotate_bars
annotate_bars(
ax: 'Axes | None' = None,
v_offset: 'float' = 10,
h_offset: 'float' = 0,
labels: 'Sequence[str | int | float] | None' = None,
fontsize: 'int' = 14,
y_max_headroom: 'float' = 1.2,
adjust_test_pos: 'bool' = False,
**kwargs: 'Any'
) → None
Annotate each bar in bar plot with a label.
Args:
ax (Axes): The matplotlib axes to annotate.
v_offset (float): Vertical offset between the labels and the bars. Defaults to 10.
h_offset (float): Horizontal offset between the labels and the bars. Defaults to 0.
labels (list[str]): Labels used for annotating bars. If not provided, defaults to the y-value of each bar.
fontsize (int): Annotated text size in pts. Defaults to 14.
y_max_headroom (float): Will be multiplied with the y-value of the tallest bar to increase the y-max of the plot, thereby making room for text above all bars. Defaults to 1.2.
adjust_test_pos (bool): If True, use adjustText to prevent overlapping labels. Defaults to False.
**kwargs: Additional arguments (rotation, arrowprops, etc.) are passed to ax.annotate().
function annotate_metrics
annotate_metrics(
xs: 'ArrayLike',
ys: 'ArrayLike',
fig: 'AxOrFig | None' = None,
metrics: 'dict[str, float] | Sequence[str]' = ('MAE', 'R2'),
prefix: 'str' = '',
suffix: 'str' = '',
fmt: 'str' = '.3',
**kwargs: 'Any'
) → AnchoredText
Provide a set of x and y values of equal length and an optional Axes object on which to print the values’ mean absolute error and R^2 coefficient of determination.
Args:
xs (array): x values.
ys (array): y values.
fig (plt.Axes | plt.Figure | go.Figure | None, optional): matplotlib Axes or Figure or plotly Figure on which to add the annotation. Defaults to None.
metrics (dict[str, float] | Sequence[str], optional): Metrics to show. Can be a subset of recognized keys MAE, R2, R2_adj, RMSE, MSE, MAPE or the names of sklearn.metrics.regression functions or any dict of metric names and values. Defaults to ("MAE", "R2").
prefix (str, optional): Title or other string to prepend to metrics. Defaults to "".
suffix (str, optional): Text to append after metrics. Defaults to "".
fmt (str, optional): f-string float format for metrics. Defaults to ".3".
**kwargs: Additional arguments to pass to annotate().
Returns:
plt.Axes | plt.Figure | go.Figure: The annotated figure.
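For orientation, the two default metrics can be computed by hand. The following is a minimal pure-Python sketch, not the library's implementation. The helper names and the choice to treat xs as the reference values are assumptions made for illustration.

```python
def mean_absolute_error(y_true, y_pred):
    """Average absolute difference between paired values."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


xs = [1.0, 2.0, 3.0, 4.0]  # assumed to be the reference values
ys = [1.1, 1.9, 3.2, 3.8]

metrics = {"MAE": mean_absolute_error(xs, ys), "R2": r2_score(xs, ys)}

# fmt plays the same role as the fmt argument above: a float format spec
fmt = ".3"
text = "\n".join(f"{key} = {val:{fmt}}" for key, val in metrics.items())
print(text)
```

A dict like metrics above is also the shape accepted by the metrics argument when precomputed values should be shown instead of the recognized keys.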
function add_identity_line
add_identity_line(
fig: 'Figure | Figure | Axes',
line_kwds: 'dict[str, Any] | None' = None,
trace_idx: 'int' = 0,
**kwargs: 'Any'
) → Figure
Add a line shape to the background layer of a plotly figure spanning from smallest to largest x/y values in the trace specified by trace_idx.
Args:
fig (go.Figure | plt.Figure | plt.Axes): plotly/matplotlib figure or axes to add the identity line to.
line_kwds (dict[str, Any], optional): Keyword arguments for customizing the line shape will be passed to fig.add_shape(line=line_kwds). Defaults to dict(color="gray", width=1, dash="dash").
trace_idx (int, optional): Index of the trace to use for measuring x/y limits. Defaults to 0. Unused if kaleido package is installed and the figure's actual x/y-range can be obtained from fig.full_figure_for_development(). Applies only to plotly figures.
**kwargs: Additional arguments are passed to fig.add_shape().
Raises:
TypeError: If fig is neither a plotly nor a matplotlib figure or axes.
ValueError: If fig is a plotly figure and kaleido is not installed and trace_idx is out of range.
Returns:
Figure: Figure with added identity line.
function add_best_fit_line
add_best_fit_line(
fig: 'Figure | Figure | Axes',
xs: 'ArrayLike' = (),
ys: 'ArrayLike' = (),
trace_idx: 'int' = 0,
line_kwds: 'dict[str, Any] | None' = None,
annotate_params: 'bool | dict[str, Any]' = True,
**kwargs: 'Any'
) → Figure
Add line of best fit according to least squares to a plotly or matplotlib figure.
Args:
fig (go.Figure | plt.Figure | plt.Axes): plotly/matplotlib figure or axes to add the best fit line to.
xs (array, optional): x-values to use for fitting. Defaults to () which means use the x-values of trace at trace_idx in fig.
ys (array, optional): y-values to use for fitting. Defaults to () which means use the y-values of trace at trace_idx in fig.
trace_idx (int, optional): Index of the trace to use for measuring x/y values for fitting if xs and ys are not provided. Defaults to 0.
line_kwds (dict[str, Any], optional): Keyword arguments for customizing the line shape. For plotly, will be passed to fig.add_shape(line=line_kwds). For matplotlib, will be passed to ax.plot(). Defaults to None.
annotate_params (bool | dict[str, Any], optional): Pass dict to customize the annotation of the best fit line. Set to False to disable annotation. Defaults to True.
**kwargs: Additional arguments are passed to fig.add_shape() for plotly or ax.plot() for matplotlib.
Raises:
TypeError: If fig is neither a plotly nor a matplotlib figure or axes.
ValueError: If fig is a plotly figure and xs and ys are not provided and trace_idx is out of range.
Returns:
Figure: Figure with added best fit line.
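As a reminder of what "best fit according to least squares" means for a straight line, here is a small pure-Python sketch of the closed-form solution. The function name least_squares_line is hypothetical and this is not necessarily how the library computes the fit.

```python
def least_squares_line(xs, ys):
    """Return (slope, intercept) minimizing the sum of squared residuals."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # slope = covariance(x, y) / variance(x)
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    slope = cov_xy / var_x
    intercept = mean_y - slope * mean_x
    return slope, intercept


xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]
slope, intercept = least_squares_line(xs, ys)

# End points of the fitted line across the data's x-range
x_min, x_max = min(xs), max(xs)
line_start = (x_min, slope * x_min + intercept)
line_end = (x_max, slope * x_max + intercept)
print(slope, intercept, line_start, line_end)
```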
function add_ecdf_line
add_ecdf_line(
fig: 'Figure',
values: 'ArrayLike' = (),
trace_idx: 'int' = 0,
trace_kwargs: 'dict[str, Any] | None' = None,
**kwargs: 'Any'
) → Figure
Add an empirical cumulative distribution function (ECDF) line to a plotly figure.
Support for matplotlib planned but not implemented. PRs welcome.
Args:
fig (go.Figure): plotly figure to add the ECDF line to.
values (array, optional): Values to compute the ECDF from. Defaults to () which means use the x-values of trace at trace_idx in fig.
trace_idx (int, optional): Index of the trace whose x-values to use for computing the ECDF. Defaults to 0. Unused if values is not empty.
trace_kwargs (dict[str, Any], optional): Passed to trace_ecdf.update(). Defaults to None. Use e.g. to set trace name (default "Cumulative") or line_color (default "gray").
**kwargs: Passed to fig.add_trace().
Returns:
Figure: Figure with added ECDF line.
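The ECDF itself is simple to compute: sort the values and assign each the fraction of samples less than or equal to it. The sketch below uses only the standard library; the function name ecdf is an assumption for illustration and does not mirror the library's internals.

```python
def ecdf(values):
    """Return sorted values and the cumulative fraction of samples at each."""
    sorted_vals = sorted(values)
    n = len(sorted_vals)
    cumulative = [(idx + 1) / n for idx in range(n)]
    return sorted_vals, cumulative


ecdf_x, ecdf_y = ecdf([3, 1, 2, 2])
print(ecdf_x)
print(ecdf_y)
```

The resulting x/y pairs are what a cumulative line trace would plot, rising from 1/n at the smallest value to 1 at the largest.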
module ptable
Various periodic table heatmaps with matplotlib and plotly.
Global Variables
- TYPE_CHECKING
- ELEM_CLASS_COLORS
function add_element_type_legend
add_element_type_legend(
data: 'DataFrame | Series | dict[str, list[float]]',
elem_class_colors: 'dict[str, str] | None' = None,
legend_kwargs: 'dict[str, Any] | None' = None
) → None
Add a legend to a matplotlib figure showing the colors of element types.
Args:
data (pd.DataFrame | pd.Series | dict[str, list[float]]): Map from element to plot data. Used only to determine which element types are present.
elem_class_colors (dict[str, str]): Map from element types to colors. E.g. {"Alkali Metal": "red", "Noble Gas": "blue"}.
legend_kwargs (dict): Keyword arguments passed to plt.legend() for customizing legend appearance. Defaults to None.
function count_elements
count_elements(
values: 'ElemValues',
count_mode: 'CountMode' = 'composition',
exclude_elements: 'Sequence[str]' = (),
fill_value: 'float | None' = 0
) → Series
Count element occurrence in list of formula strings or dict-like compositions. If passed values are already a map from element symbol to counts, ensure the data is a pd.Series filled with zero values for missing element symbols.
Provided as standalone function for external use or to cache long computations. Caching long element counts is done by refactoring
ptable_heatmap(long_list_of_formulas)  # slow
to
elem_counts = count_elements(long_list_of_formulas)  # slow
ptable_heatmap(elem_counts)  # fast, only rerun this line to update the plot
Args:
values (dict[str, int | float] | pd.Series | list[str]): Iterable of composition strings/objects or map from element symbols to heatmap values.
count_mode ('(element|fractional|reduced)_composition'): Only used when values is a list of composition strings/objects.
- composition (default): Count elements in each composition as is, i.e. without reduction or normalization.
- fractional_composition: Convert to normalized compositions in which the amounts of each species sum to 1 before counting. Example: Fe2 O3 -> Fe0.4 O0.6
- reduced_composition: Convert to reduced compositions (i.e. amounts normalized by greatest common divisor) before counting. Example: Fe4 P4 O16 -> Fe P O4.
- occurrence: Count the number of times each element occurs in a list of formulas irrespective of compositions. E.g. [Fe2 O3, Fe O, Fe4 P4 O16] counts to {Fe: 3, O: 3, P: 1}.
exclude_elements (Sequence[str]): Elements to exclude from the count. Defaults to ().
fill_value (float | None): Value to fill in for missing elements. Defaults to 0.
Returns:
pd.Series: Map element symbols to heatmap values.
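The four count modes can be illustrated without any dependencies. The sketch below is a simplified stand-in, not the library's code: it assumes compositions are already parsed into dicts of integer amounts (the real function accepts formula strings and composition objects), and count_elements_sketch is a hypothetical name.

```python
from collections import Counter
from functools import reduce
from math import gcd


def count_elements_sketch(compositions, count_mode="composition"):
    """Sum per-element amounts over compositions given as {element: amount}."""
    totals = Counter()
    for comp in compositions:
        if count_mode == "composition":
            amounts = dict(comp)  # amounts as is
        elif count_mode == "fractional_composition":
            total = sum(comp.values())
            amounts = {elem: amt / total for elem, amt in comp.items()}
        elif count_mode == "reduced_composition":
            divisor = reduce(gcd, comp.values())  # assumes integer amounts
            amounts = {elem: amt // divisor for elem, amt in comp.items()}
        elif count_mode == "occurrence":
            amounts = {elem: 1 for elem in comp}
        else:
            raise ValueError(f"Unknown count_mode: {count_mode}")
        totals.update(amounts)  # adds amounts to the running totals
    return dict(totals)


# Fe2 O3, Fe O, Fe4 P4 O16 from the occurrence example above
formulas = [{"Fe": 2, "O": 3}, {"Fe": 1, "O": 1}, {"Fe": 4, "P": 4, "O": 16}]

by_composition = count_elements_sketch(formulas, "composition")
by_reduced = count_elements_sketch(formulas, "reduced_composition")
by_occurrence = count_elements_sketch(formulas, "occurrence")
fractional_fe2o3 = count_elements_sketch(formulas[:1], "fractional_composition")
print(by_composition, by_reduced, by_occurrence, fractional_fe2o3)
```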
function data_preprocessor
data_preprocessor(data: 'SupportedDataType') → DataFrame
Preprocess input data for ptable plotters, including:
- Convert all data types to pd.DataFrame.
- Impute missing values.
- Handle anomalies such as NaN, infinity.
- Write vmin/vmax as metadata into the DataFrame.
Returns:
pd.DataFrame
: The preprocessed DataFrame with element names as index and values as columns.
Example:
data_dict: dict = {
"H": 1.0,
"He": [2.0, 4.0],
"Li": [[6.0, 8.0], [10.0, 12.0]],
}
OR
data_df: pd.DataFrame = pd.DataFrame(data_dict.items(), columns=["Element", "Value"]).set_index("Element")
OR
data_series: pd.Series = pd.Series(data_dict)
data_preprocessor(data_dict/df/series)
   Element  Value
0  H        [1.0, ]
1  He       [2.0, 4.0]
2  Li       [[6.0, 8.0], [10.0, 12.0]]
Metadata:
vmin: 1.0
vmax: 12.0
function handle_missing_and_anomaly
handle_missing_and_anomaly(df: 'DataFrame') → DataFrame
Handle missing values (NaN) and anomalies (infinity).
Infinity is replaced by vmax (for ∞) or vmin (for -∞). Missing values are handled by the selected strategy:
- zero: impute with zeros
- mean: impute with mean value
TODO: finish this function
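Since the function is marked as unfinished, the following is only a sketch of the behavior described above, written against a flat list instead of a DataFrame. The name, the strategy argument and the decision to compute vmin, vmax and the mean from finite values only are all assumptions.

```python
import math


def handle_missing_and_anomaly_sketch(values, strategy="mean"):
    """Replace +/-inf by vmax/vmin and impute NaN/None by zero or the mean."""
    finite = [val for val in values if val is not None and math.isfinite(val)]
    vmin, vmax = min(finite), max(finite)
    fill = 0.0 if strategy == "zero" else sum(finite) / len(finite)

    cleaned = []
    for val in values:
        if val is None or math.isnan(val):
            cleaned.append(fill)  # missing value
        elif val == math.inf:
            cleaned.append(vmax)  # positive infinity
        elif val == -math.inf:
            cleaned.append(vmin)  # negative infinity
        else:
            cleaned.append(val)
    return cleaned


raw = [1.0, float("nan"), float("inf"), 3.0, float("-inf")]
cleaned_mean = handle_missing_and_anomaly_sketch(raw, strategy="mean")
cleaned_zero = handle_missing_and_anomaly_sketch(raw, strategy="zero")
print(cleaned_mean, cleaned_zero)
```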
function ptable_heatmap
ptable_heatmap(
values: 'ElemValues',
log: 'bool | Normalize' = False,
ax: 'Axes | None' = None,
count_mode: 'CountMode' = 'composition',
cbar_title: 'str' = 'Element Count',
cbar_range: 'tuple[float | None, float | None] | None' = None,
cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.05),
cbar_kwargs: 'dict[str, Any] | None' = None,
colorscale: 'str' = 'viridis',
show_scale: 'bool' = True,
show_values: 'bool' = True,
infty_color: 'str' = 'lightskyblue',
na_color: 'str' = 'white',
heat_mode: "Literal['value', 'fraction', 'percent'] | None" = 'value',
fmt: 'str | Callable[, str] | None' = None,
cbar_fmt: 'str | Callable[, str] | None' = None,
text_color: 'str | tuple[str, str]' = 'auto',
exclude_elements: 'Sequence[str]' = (),
zero_color: 'str' = '#eff',
zero_symbol: 'str | float' = '-',
text_style: 'dict[str, Any] | None' = None,
label_font_size: 'int' = 16,
value_font_size: 'int' = 12,
tile_size: 'float | tuple[float, float]' = 0.9,
rare_earth_voffset: 'float' = 0.5,
**kwargs: 'Any'
) → Axes
Plot a heatmap across the periodic table of elements.
Args:
values (dict[str, int | float] | pd.Series | list[str]): Map from element symbols to heatmap values or iterable of composition strings/objects.
log (bool | Normalize, optional): Whether colormap scale is log or linear. Can also take any matplotlib.colors.Normalize subclass such as SymLogNorm as custom color scale. Defaults to False.
ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
count_mode ('composition' | 'fractional_composition' | 'reduced_composition'): Reduce or normalize compositions before counting. See count_elements() for details. Only used when values is list of composition strings/objects.
cbar_title (str, optional): Color bar title. Defaults to "Element Count".
cbar_range (tuple[float | None, float | None], optional): Color bar range. Can be used e.g. to create multiple plots with identical color bars for visual comparison. Defaults to automatic based on data range.
cbar_coords (tuple[float, float, float, float], optional): Color bar position and size: [x, y, width, height] anchored at lower left corner of the bar. Defaults to (0.18, 0.8, 0.42, 0.05).
cbar_kwargs (dict[str, Any], optional): Additional keyword arguments passed to fig.colorbar(). Defaults to None.
colorscale (str, optional): Matplotlib colormap name to use. Defaults to "viridis". See https://matplotlib.org/stable/users/explain/colors/colormaps for available options.
show_scale (bool, optional): Whether to show the color bar. Defaults to True.
show_values (bool, optional): Whether to show the heatmap values in each tile. Defaults to True.
infty_color: Color to use for elements with value infinity. Defaults to "lightskyblue".
na_color: Color to use for elements with missing values (NaN). Defaults to "white".
heat_mode ("value" | "fraction" | "percent" | None): Whether to display heat values as is, normalized as a fraction of the total, as percentages or not at all (None). Defaults to "value". "fraction" and "percent" can be used to make the colors in different ptable_heatmap() (and ptable_heatmap_ratio()) plots comparable.
fmt (str): f-string format option for tile values. Defaults to ".1%" (1 decimal place) if heat_mode="percent" else ".3g". Use e.g. ",.0f" to format values with thousands separators and no decimal places.
cbar_fmt (str): f-string format option to set a different color bar tick label format. Defaults to the above fmt.
text_color (str | tuple[str, str]): What color to use for element symbols and heat labels. Must be a valid color name, or a 2-tuple of names, one to use for the upper half of the color scale, one for the lower half. The special value "auto" applies "black" on the lower and "white" on the upper half of the color scale. "auto_reverse" does the opposite. Defaults to "auto".
exclude_elements (list[str]): Elements to exclude from the heatmap. E.g. if oxygen overpowers everything, you can try log=True or exclude_elements=["O"]. Defaults to ().
zero_color (str): Hex color or recognized matplotlib color name to use for elements with value zero. Defaults to "#eff" (light gray).
zero_symbol (str | float): Symbol to use for elements with value zero. Defaults to "-".
text_style (dict[str, Any]): Additional keyword arguments passed to plt.text(). Defaults to dict(ha="center", fontsize=label_font_size, fontweight="semibold").
label_font_size (int): Font size for element symbols. Defaults to 16.
value_font_size (int): Font size for heat values. Defaults to 12.
tile_size (float | tuple[float, float]): Size of each tile in the periodic table as a fraction of available space before touching neighboring tiles. 1 or (1, 1) means no gaps between tiles. Defaults to 0.9.
rare_earth_voffset (float): Vertical offset for lanthanides and actinides (row 6 and 7) from the rest of the periodic table. Defaults to 0.5.
**kwargs: Additional keyword arguments passed to plt.figure().
Returns:
plt.Axes: matplotlib Axes with the heatmap.
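To make the interplay of heat_mode and fmt concrete, the sketch below shows how tile labels could be derived from raw counts using standard Python format specs. It is an illustration under assumptions, not the plotting code: heat_labels is a hypothetical helper that only mimics the documented defaults.

```python
def heat_labels(values, heat_mode="value", fmt=None):
    """Format per-element values the way the fmt/heat_mode docs describe."""
    if heat_mode is None:
        return {}  # no heat values shown at all
    if fmt is None:
        fmt = ".1%" if heat_mode == "percent" else ".3g"
    total = sum(values.values())
    labels = {}
    for elem, val in values.items():
        if heat_mode in ("fraction", "percent"):
            val = val / total  # normalize so all values sum to 1
        labels[elem] = f"{val:{fmt}}"
    return labels


counts = {"Fe": 1500, "O": 4500}
as_value = heat_labels(counts, heat_mode="value", fmt=",.0f")
as_fraction = heat_labels(counts, heat_mode="fraction")
as_percent = heat_labels(counts, heat_mode="percent")
print(as_value, as_fraction, as_percent)
```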
function ptable_heatmap_splits
ptable_heatmap_splits(
data: 'DataFrame | Series | dict[str, list[list[float]]]',
colormap: 'str | None' = None,
start_angle: 'float' = 135,
symbol_text: 'str | Callable[[Element], str]' = lambda elem: elem.symbol,
symbol_pos: 'tuple[float, float]' = (0.5, 0.5),
cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
cbar_title: 'str' = 'Values',
on_empty: "Literal['hide', 'show']" = 'hide',
ax_kwargs: 'dict[str, Any] | None' = None,
symbol_kwargs: 'dict[str, Any] | None' = None,
plot_kwargs: 'dict[str, Any] | Callable[[Sequence[float]], dict[str, Any]] | None' = None,
cbar_title_kwargs: 'dict[str, Any] | None' = None,
cbar_kwargs: 'dict[str, Any] | None' = None
) → Figure
Plot evenly-split heatmaps, nested inside a periodic table.
Args:
data (pd.DataFrame | pd.Series | dict[str, list[list[float]]]): Map from element symbols to plot data. E.g. if dict, {"Fe": [1, 2], "Co": [3, 4]}, where the 1st value would be plotted on the lower-left corner and the 2nd on the upper-right. If pd.Series, index is element symbols and values lists. If pd.DataFrame, column names are element symbols, plots are created from each column.
colormap (str): Matplotlib colormap name to use.
start_angle (float): The starting angle for the splits in degrees, and the split proceeds counter-clockwise (0 refers to the x-axis). Defaults to 135 degrees.
ax_kwargs (dict): Keyword arguments passed to ax.set() for each plot. Use to set x/y labels, limits, etc. Defaults to None. Example: dict(title="Periodic Table", xlabel="x-axis", ylabel="y-axis", xlim=(0, 10), ylim=(0, 10), xscale="linear", yscale="log"). See ax.set() docs for options: https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set.html#matplotlib-axes-axes-set
symbol_text (str | Callable[[Element], str]): Text to display for each element symbol. Defaults to lambda elem: elem.symbol.
symbol_kwargs (dict): Keyword arguments passed to plt.text() for element symbols. Defaults to None.
symbol_pos (tuple[float, float]): Position of element symbols relative to the lower left corner of each tile. Defaults to (0.5, 0.5). (1, 1) is the upper right corner.
cbar_coords (tuple[float, float, float, float]): Colorbar position and size: [x, y, width, height] anchored at lower left corner of the bar. Defaults to (0.18, 0.8, 0.42, 0.02).
cbar_title (str): Colorbar title. Defaults to "Values".
cbar_title_kwargs (dict): Keyword arguments passed to cbar.ax.set_title(). Defaults to dict(fontsize=12, pad=10).
cbar_kwargs (dict): Keyword arguments passed to fig.colorbar().
on_empty ('hide' | 'show'): Whether to show or hide tiles for elements without data. Defaults to "hide".
plot_kwargs (dict): Additional keyword arguments to pass to the plt.subplots function call.
Notes:
Default figsize is set to (0.75 * n_groups, 0.75 * n_periods).
Returns:
plt.Figure: periodic table with a subplot in each element tile.
function ptable_heatmap_ratio
ptable_heatmap_ratio(
values_num: 'ElemValues',
values_denom: 'ElemValues',
count_mode: 'CountMode' = 'composition',
normalize: 'bool' = False,
cbar_title: 'str' = 'Element Ratio',
not_in_numerator: 'tuple[str, str] | None' = ('#eff', 'gray: not in 1st list'),
not_in_denominator: 'tuple[str, str] | None' = ('lightskyblue', 'blue: not in 2nd list'),
not_in_either: 'tuple[str, str] | None' = ('white', 'white: not in either'),
**kwargs: 'Any'
) → Axes
Display the ratio of two maps from element symbols to heat values or of two sets of compositions.
Args:
values_num (dict[str, int | float] | pd.Series | list[str]): Map from element symbols to heatmap values or iterable of composition strings/objects in the numerator.
values_denom (dict[str, int | float] | pd.Series | list[str]): Map from element symbols to heatmap values or iterable of composition strings/objects in the denominator.
normalize (bool): Whether to normalize heatmap values so they sum to 1. Makes different ptable_heatmap_ratio plots comparable. Defaults to False.
count_mode ('composition' | 'fractional_composition' | 'reduced_composition'): Reduce or normalize compositions before counting. See count_elements() for details. Only used when values is list of composition strings/objects.
cbar_title (str): Title for the colorbar. Defaults to "Element Ratio".
not_in_numerator (tuple[str, str]): Color and legend description used for elements missing from numerator. Defaults to ('#eff', 'gray: not in 1st list').
not_in_denominator (tuple[str, str]): See not_in_numerator. Defaults to ('lightskyblue', 'blue: not in 2nd list').
not_in_either (tuple[str, str]): See not_in_numerator. Defaults to ('white', 'white: not in either').
**kwargs: Additional keyword arguments passed to ptable_heatmap().
Returns:
plt.Axes: matplotlib Axes object
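The element-wise ratio and the two "missing" categories can be sketched with plain dicts. This is a hedged illustration only: element_ratio_sketch is a hypothetical helper, it ignores normalize and count_mode, and it does not reproduce how the library colors the tiles.

```python
def element_ratio_sketch(values_num, values_denom):
    """Divide numerator by denominator per element and track missing ones."""
    ratios = {}
    not_in_numerator = []
    not_in_denominator = []
    for elem in sorted(set(values_num) | set(values_denom)):
        if elem not in values_num:
            not_in_numerator.append(elem)  # present only in the 2nd map
        elif elem not in values_denom:
            not_in_denominator.append(elem)  # present only in the 1st map
        else:
            ratios[elem] = values_num[elem] / values_denom[elem]
    return ratios, not_in_numerator, not_in_denominator


numerator = {"Fe": 4, "O": 6}
denominator = {"Fe": 2, "Li": 1}
ratios, only_in_denom, only_in_num = element_ratio_sketch(numerator, denominator)
print(ratios, only_in_denom, only_in_num)
```

Elements that appear in neither map never show up in this sketch; in the plot they correspond to the not_in_either category.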
function ptable_heatmap_plotly
ptable_heatmap_plotly(
values: 'ElemValues',
count_mode: 'CountMode' = 'composition',
colorscale: 'str | Sequence[str] | Sequence[tuple[float, str]]' = 'viridis',
show_scale: 'bool' = True,
show_values: 'bool' = True,
heat_mode: "Literal['value', 'fraction', 'percent'] | None" = 'value',
fmt: 'str | None' = None,
hover_props: 'Sequence[str] | dict[str, str] | None' = None,
hover_data: 'dict[str, str | int | float] | Series | None' = None,
font_colors: 'Sequence[str]' = (),
gap: 'float' = 5,
font_size: 'int | None' = None,
bg_color: 'str | None' = None,
color_bar: 'dict[str, Any] | None' = None,
cscale_range: 'tuple[float | None, float | None]' = (None, None),
exclude_elements: 'Sequence[str]' = (),
log: 'bool' = False,
fill_value: 'float | None' = None,
label_map: 'dict[str, str] | Callable[[str], str] | Literal[False] | None' = None,
**kwargs: 'Any'
) → Figure
Create a Plotly figure with an interactive heatmap of the periodic table. Supports hover tooltips with custom data or atomic reference data like electronegativity, atomic_radius, etc. See kwargs hover_data and hover_props, resp.
Args:
values (dict[str, int | float] | pd.Series | list[str]): Map from element symbols to heatmap values e.g. dict(Fe=2, O=3) or iterable of composition strings or Pymatgen composition objects.
count_mode ("composition" | "fractional_composition" | "reduced_composition"): Reduce or normalize compositions before counting. See count_elements() for details. Only used when values is list of composition strings/objects.
colorscale (str | list[str] | list[tuple[float, str]]): Color scale for heatmap. Defaults to 'viridis'. See plotly.com/python/builtin-colorscales for names of other builtin color scales. Note "YlGn" and px.colors.sequential.YlGn are equivalent. Custom scales are specified as ["blue", "red"] or [[0, "rgb(0,0,255)"], [0.5, "rgb(0,255,0)"], [1, "rgb(255,0,0)"]].
show_scale (bool): Whether to show a bar for the color scale. Defaults to True.
show_values (bool): Whether to show numbers on heatmap tiles. Defaults to True.
heat_mode ("value" | "fraction" | "percent" | None): Whether to display heat values as is (value), normalized as a fraction of the total, as percentages or not at all (None). Defaults to "value". "fraction" and "percent" can be used to make the colors in different periodic table heatmap plots comparable.
fmt (str): f-string format option for heat labels. Defaults to ".1%" (1 decimal place) if heat_mode="percent" else ".3g".
hover_props (list[str] | dict[str, str]): Elemental properties to display in the hover tooltip. Can be a list of property names to display only the values themselves or a dict mapping names to what they should display as. E.g. dict(atomic_mass="atomic weight") will display as "atomic weight = {x}". Defaults to None. Available properties are: symbol, row, column, name, atomic_number, atomic_mass, n_neutrons, n_protons, n_electrons, period, group, phase, radioactive, natural, metal, nonmetal, metalloid, type, atomic_radius, electronegativity, first_ionization, density, melting_point, boiling_point, number_of_isotopes, discoverer, year, specific_heat, n_shells, n_valence.
hover_data (dict[str, str | int | float] | pd.Series): Map from element symbols to additional data to display in the hover tooltip. dict(Fe="this appears in the hover tooltip on a new line below the element name"). Defaults to None.
font_colors (list[str]): One color name or two for [min_color, max_color]. min_color is applied to annotations with heatmap values less than (max_val - min_val) / 2. Defaults to None, meaning auto-set to maximize contrast with color scale: white text for dark background and vice versa (swapped depending on the colorscale).
gap (float): Gap in pixels between tiles of the periodic table. Defaults to 5.
font_size (int): Element symbol and heat label text size. Any valid CSS size allowed. Defaults to automatic font size based on plot size. Element symbols will be bold and 1.5x this size.
bg_color (str): Plot background color. Defaults to "rgba(0, 0, 0, 0)".
color_bar (dict[str, Any]): Plotly colorbar properties documented at https://plotly.com/python/reference#heatmap-colorbar. Defaults to dict(orientation="h"). Commonly used keys are:
- title: colorbar title
- titleside: "top" | "bottom" | "right" | "left"
- tickmode: "array" | "auto" | "linear" | "log" | "date" | "category"
- tickvals: list of tick values
- ticktext: list of tick labels
- tickformat: f-string format option for tick labels
- len: fraction of plot height or width depending on orientation
- thickness: fraction of plot height or width depending on orientation
cscale_range (tuple[float | None, float | None]): Colorbar range. Defaults to (None, None) meaning the range is automatically determined from the data.
exclude_elements (list[str]): Elements to exclude from the heatmap. E.g. if oxygen overpowers everything, you can do exclude_elements=['O']. Defaults to ().
log (bool): Whether to use a logarithmic color scale. Defaults to False. Piece of advice: colorscale='viridis' and log=True go well together.
fill_value (float | None): Value to fill in for missing elements. Defaults to 0.
label_map (dict[str, str] | Callable[[str], str] | None): Map heat values (after string formatting) to target strings. Set to False to disable. Defaults to dict.fromkeys((np.nan, None, "nan"), " ") so as not to display 'nan' for missing values.
**kwargs: Additional keyword arguments passed to plotly.figure_factory.create_annotated_heatmap().
Returns:
Figure: Plotly Figure object.
function ptable_hists
ptable_hists(
data: 'DataFrame | Series | dict[str, list[float]]',
bins: 'int' = 20,
colormap: 'str | None' = None,
hist_kwds: 'dict[str, Any] | Callable[[Sequence[float]], dict[str, Any]] | None' = None,
cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
x_range: 'tuple[float | None, float | None] | None' = None,
symbol_kwargs: 'Any' = None,
symbol_text: 'str | Callable[[Element], str]' = lambda elem: elem.symbol,
cbar_title: 'str' = 'Values',
cbar_title_kwds: 'dict[str, Any] | None' = None,
cbar_kwds: 'dict[str, Any] | None' = None,
symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
log: 'bool' = False,
anno_kwds: 'dict[str, Any] | None' = None,
on_empty: "Literal['show', 'hide']" = 'hide',
color_elem_types: "Literal['symbol', 'background', 'both', False] | dict[str, str]" = 'background',
elem_type_legend: 'bool | dict[str, Any]' = True,
**kwargs: 'Any'
) → Figure
Plot small histograms for each element laid out in a periodic table.
Args:
data (pd.DataFrame | pd.Series | dict[str, list[float]]): Map from element symbols to histogram values. E.g. if dict, {"Fe": [1, 2, 3], "O": [4, 5]}. If pd.Series, index is element symbols and values lists. If pd.DataFrame, column names are element symbols and histograms are plotted from each column.
bins (int): Number of bins for the histograms. Defaults to 20.
colormap (str): Matplotlib colormap name to use. Defaults to None. See options at https://matplotlib.org/stable/users/explain/colors/colormaps.
hist_kwds (dict | Callable): Keywords passed to ax.hist() for each histogram. If callable, it is called with the histogram values for each element and should return a dict of keyword arguments. Defaults to None.
cbar_coords (tuple[float, float, float, float]): Color bar position and size: [x, y, width, height] anchored at lower left corner of the bar. Defaults to (0.18, 0.8, 0.42, 0.02).
x_range (tuple[float | None, float | None]): x-axis range for all histograms. Defaults to None.
symbol_text (str | Callable[[Element], str]): Text to display for each element symbol. Defaults to lambda elem: elem.symbol.
symbol_kwargs (dict): Keyword arguments passed to plt.text() for element symbols. Defaults to None.
cbar_title (str): Color bar title. Defaults to "Values".
cbar_title_kwds (dict): Keyword arguments passed to cbar.ax.set_title(). Defaults to dict(fontsize=12, pad=10).
cbar_kwds (dict): Keyword arguments passed to fig.colorbar().
symbol_pos (tuple[float, float]): Position of element symbols relative to the lower left corner of each tile. Defaults to (0.5, 0.8). (1, 1) is the upper right corner.
log (bool): Whether to log scale y-axis of each histogram. Defaults to False.
anno_kwds (dict): Keyword arguments passed to plt.annotate() for element annotations. Defaults to None. Useful for adding e.g. number of data points in each histogram. For that, use anno_kwds=lambda hist_vals: dict(text=len(hist_vals)). Recognized keys are text, xy, xycoords, fontsize, and any other plt.annotate() keywords.
on_empty ('hide' | 'show'): Whether to show or hide tiles for elements without data. Defaults to "hide".
color_elem_types ('symbol' | 'background' | 'both' | False | dict): Whether to color element symbols, tile backgrounds, or both based on element type. If dict, it should map element types to colors. Defaults to "background".
elem_type_legend (bool | dict): Whether to show a legend for element types. Defaults to True. If dict, used as kwargs to plt.legend(), e.g. to set the legend title, use {"title": "Element Types"}.
**kwargs: Additional keyword arguments passed to plt.subplots(). Defaults to dict(figsize=(0.75 * n_columns, 0.75 * n_rows)) with n_columns/n_rows the number of columns/rows in the periodic table.
Returns:
plt.Figure: periodic table with a histogram in each element tile.
function ptable_scatters
ptable_scatters(
data: 'DataFrame | Series | dict[str, list[list[float]]]',
symbol_text: 'str | Callable[[Element], str]' = lambda elem: elem.symbol,
symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
on_empty: "Literal['hide', 'show']" = 'hide',
plot_kwargs: 'dict[str, Any] | Callable[[Sequence[float]], dict[str, Any]] | None' = None,
child_args: 'dict[str, Any] | None' = None,
ax_kwargs: 'dict[str, Any] | None' = None,
symbol_kwargs: 'dict[str, Any] | None' = None
) → Figure
Make scatter plots for each element, nested inside a periodic table.
Args:
data (pd.DataFrame | pd.Series | dict[str, list[list[float]]]): Map from element symbols to plot data. E.g. if dict, {"Fe": [1, 2], "Co": [3, 4]}, where the 1st value would be plotted on the lower-left corner and the 2nd on the upper-right. If pd.Series, index is element symbols and values lists. If pd.DataFrame, column names are element symbols, plots are created from each column.
ax_kwargs (dict): Keyword arguments passed to ax.set() for each plot. Use to set x/y labels, limits, etc. Defaults to None. Example: dict(title="Periodic Table", xlabel="x-axis", ylabel="y-axis", xlim=(0, 10), ylim=(0, 10), xscale="linear", yscale="log"). See ax.set() docs for options: https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set.html#matplotlib-axes-axes-set
symbol_text (str | Callable[[Element], str]): Text to display for each element symbol. Defaults to lambda elem: elem.symbol.
symbol_kwargs (dict): Keyword arguments passed to plt.text() for element symbols. Defaults to None.
symbol_pos (tuple[float, float]): Position of element symbols relative to the lower left corner of each tile. Defaults to (0.5, 0.8). (1, 1) is the upper right corner.
on_empty ('hide' | 'show'): Whether to show or hide tiles for elements without data. Defaults to "hide".
child_args: Arguments to pass to the child plotter call.
plot_kwargs (dict): Additional keyword arguments to pass to the plt.subplots function call.
TODO: allow colormap with 3rd data dimension
function ptable_lines
ptable_lines(
data: 'DataFrame | Series | dict[str, list[list[float]]]',
symbol_text: 'str | Callable[[Element], str]' = lambda elem: elem.symbol,
symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
on_empty: "Literal['hide', 'show']" = 'hide',
plot_kwargs: 'dict[str, Any] | Callable[[Sequence[float]], dict[str, Any]] | None' = None,
child_args: 'dict[str, Any] | None' = None,
ax_kwargs: 'dict[str, Any] | None' = None,
symbol_kwargs: 'dict[str, Any] | None' = None
) → Figure
Line plots for each element, nested inside a periodic table.
Args:
data (pd.DataFrame | pd.Series | dict[str, list[list[float]]]): Map from element symbols to plot data. E.g. if dict, {"Fe": [1, 2], "Co": [3, 4]}, where the 1st value would be plotted on the lower-left corner and the 2nd on the upper-right. If pd.Series, index is element symbols and values lists. If pd.DataFrame, column names are element symbols, plots are created from each column.
ax_kwargs (dict): Keyword arguments passed to ax.set() for each plot. Use to set x/y labels, limits, etc. Defaults to None. Example: dict(title="Periodic Table", xlabel="x-axis", ylabel="y-axis", xlim=(0, 10), ylim=(0, 10), xscale="linear", yscale="log"). See ax.set() docs for options: https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set.html#matplotlib-axes-axes-set
symbol_text (str | Callable[[Element], str]): Text to display for each element symbol. Defaults to lambda elem: elem.symbol.
symbol_kwargs (dict): Keyword arguments passed to plt.text() for element symbols. Defaults to None.
symbol_pos (tuple[float, float]): Position of element symbols relative to the lower left corner of each tile. Defaults to (0.5, 0.8). (1, 1) is the upper right corner.
on_empty ('hide' | 'show'): Whether to show or hide tiles for elements without data. Defaults to "hide".
child_args: Arguments to pass to the child plotter call.
plot_kwargs (dict): Additional keyword arguments to pass to the plt.subplots function call.
class PTableProjector
Project (nest) a custom plot into a periodic table.
Scopes mentioned in this plotter:
- plot: Refers to the global scope.
- ax: Refers to the axis where child plotter would plot.
- child: Refers to the child plotter itself, for example, ax.plot().
method __init__
__init__(
data: 'SupportedDataType',
colormap: 'str | Colormap | None',
plot_kwargs: 'dict[str, Any] | None' = None
) → None
Initialize a ptable projector.
Default figsize is set to (0.75 * n_groups, 0.75 * n_periods), and axes are turned off by default.
Args:
data
(SupportedDataType): The data to be visualized.
colormap
(str | Colormap | None): The colormap to use.
plot_kwargs
(dict): Additional keyword arguments to pass to the plt.subplots function call.
property cmap
The global Colormap.
Returns:
Colormap
: The Colormap used.
property data
The preprocessed data.
Returns:
pd.DataFrame
: The preprocessed data.
property norm
Data min-max normalizer.
method add_child_plots
add_child_plots(
child_plotter: 'Callable[[axes, Any], None]',
child_args: 'dict[str, Any]',
ax_kwargs: 'dict[str, Any]',
on_empty: "Literal['hide', 'show']" = 'hide'
) → None
Add custom child plots to the periodic table grid.
Args:
child_plotter
: A callable for the child plotter.
child_args
: Arguments to pass to the child plotter call.
ax_kwargs
: Keyword arguments to pass to ax.set().
on_empty
: Whether to “show” or “hide” tiles for elements without data.
method add_colorbar
add_colorbar(
title: 'str',
coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
cbar_kwargs: 'dict[str, Any] | None' = None,
title_kwargs: 'dict[str, Any] | None' = None
) → None
Add a global colorbar.
Args:
title
: Title for the colorbar.
coords
: Coordinates of the colorbar (left, bottom, width, height). Defaults to (0.18, 0.8, 0.42, 0.02).
cbar_kwargs
: Additional keyword arguments to pass to fig.colorbar().
title_kwargs
: Additional keyword arguments for the colorbar title.
method add_ele_symbols
add_ele_symbols(
text: 'str | Callable[[Element], str]' = <function PTableProjector.<lambda>>,
pos: 'tuple[float, float]' = (0.5, 0.5),
kwargs: 'dict[str, Any] | None' = None
) → None
Add element symbols for each tile.
Args:
text
(str | Callable): The text to add to the tile. If a callable, it should accept a pymatgen Element object and return a string. If a string, it can contain a format specifier for an elem variable which will be replaced by the element.
pos
: The position of the text relative to the axes.
kwargs
: Additional keyword arguments to pass to ax.text.
class ChildPlotters
Collect some pre-defined child plotters.
method line
line(ax: 'axes', data: 'SupportedValueType', **child_args: 'Any') → None
Line plotter.
Args:
ax
(plt.axes): The axis to plot on.
data
(SupportedValueType): The values for the child plotter.
child_args
(dict): Arguments to pass to the child plotter call.
method rectangle
rectangle(
ax: 'axes',
data: 'SupportedValueType',
norm: 'Normalize',
cmap: 'Colormap',
start_angle: 'float'
) → None
Rectangle heatmap plotter.
The rectangle is split evenly according to the length of the data, so tiles with different numbers of values can be mixed.
Args:
ax
(plt.axes): The axis to plot on.
data
(SupportedValueType): The values for the child plotter.
norm
(Normalize): Normalizer for data-color mapping.
cmap
(Colormap): Colormap used for value mapping.
start_angle
(float): The starting angle for the splits in degrees, and the split proceeds counter-clockwise (0 refers to the x-axis).
method scatter
scatter(ax: 'axes', data: 'SupportedValueType', **child_args: 'Any') → None
Scatter plotter.
Args:
ax
(plt.axes): The axis to plot on.
data
(SupportedValueType): The values for the child plotter.
child_args
(dict): Arguments to pass to the child plotter call.
module relevance
Plots for evaluating classifier performance.
Global Variables
- TYPE_CHECKING
function roc_curve
roc_curve(
targets: 'ArrayLike | str',
proba_pos: 'ArrayLike | str',
df: 'DataFrame | None' = None,
ax: 'Axes | None' = None
) → tuple[float, Axes]
Plot the receiver operating characteristic curve of a binary classifier given target labels and predicted probabilities for the positive class.
Args:
targets
(array): Ground truth targets.
proba_pos
(array): Predicted probabilities for the positive class.
df
(pd.DataFrame, optional): DataFrame with targets and proba_pos columns.
ax
(Axes, optional): matplotlib Axes on which to plot. Defaults to None.
Returns:
tuple[float, ax]
: The classifier’s ROC-AUC and the plot’s matplotlib Axes.
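The ROC-AUC value returned alongside the Axes can be read as the probability that a randomly chosen positive sample gets a higher predicted probability than a randomly chosen negative one. The following is a minimal, dependency-free sketch of that interpretation, not the library's implementation; the helper name roc_auc_sketch is hypothetical.

```python
def roc_auc_sketch(targets, proba_pos):
    """Pairwise-ranking estimate of ROC-AUC for binary targets (0 or 1)."""
    pos = [p for t, p in zip(targets, proba_pos) if t == 1]
    neg = [p for t, p in zip(targets, proba_pos) if t == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    # Count positive/negative pairs ranked correctly; ties count as half.
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


auc = roc_auc_sketch([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
```

Three of the four positive/negative pairs are ordered correctly here, so auc is 0.75. The pairwise loop is quadratic in the sample count and only meant to illustrate the definition.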
function precision_recall_curve
precision_recall_curve(
targets: 'ArrayLike | str',
proba_pos: 'ArrayLike | str',
df: 'DataFrame | None' = None,
ax: 'Axes | None' = None
) → tuple[float, Axes]
Plot the precision recall curve of a binary classifier.
Args:
targets
(array): Ground truth targets.
proba_pos
(array): Predicted probabilities for the positive class.
df
(pd.DataFrame, optional): DataFrame with targets and proba_pos columns.
ax
(Axes, optional): matplotlib Axes on which to plot. Defaults to None.
Returns:
tuple[float, ax]
: The classifier’s precision score and the matplotlib Axes.
module sankey
Sankey diagram for comparing distributions in two dataframe columns.
Global Variables
- TYPE_CHECKING
function sankey_from_2_df_cols
sankey_from_2_df_cols(
df: 'DataFrame',
cols: 'list[str]',
labels_with_counts: "bool | Literal['percent']" = True,
**kwargs: 'Any'
) → Figure
Plot two columns of a dataframe as a Plotly Sankey diagram.
Args:
df
(pd.DataFrame): Pandas dataframe.
cols
(list[str]): 2-tuple of source and target column names. Source corresponds to left, target to right side of the diagram.
labels_with_counts
(bool, optional): Whether to append value counts to node labels. Defaults to True.
**kwargs
: Additional keyword arguments passed to plotly.graph_objects.Sankey.
Raises:
ValueError
: If len(cols) != 2.
Returns:
Figure
: Plotly figure containing the Sankey diagram.
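Conceptually, the width of each link in such a diagram is the number of rows that share a given (source, target) value pair. Below is a rough stdlib sketch of that counting step, assuming rows are plain dicts rather than a DataFrame; count_links is a hypothetical helper, not part of pymatviz.

```python
from collections import Counter


def count_links(rows, cols):
    """Count how often each (source, target) value pair occurs."""
    if len(cols) != 2:
        raise ValueError(f"expected 2 column names, got {len(cols)}")
    source_col, target_col = cols
    return Counter((row[source_col], row[target_col]) for row in rows)


rows = [
    {"crystal_system": "cubic", "spacegroup": 225},
    {"crystal_system": "cubic", "spacegroup": 225},
    {"crystal_system": "cubic", "spacegroup": 221},
    {"crystal_system": "hexagonal", "spacegroup": 194},
]
links = count_links(rows, ["crystal_system", "spacegroup"])
```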
module structure_viz
2D plots of pymatgen structures with matplotlib.
Global Variables
- TYPE_CHECKING
function unit_cell_to_lines
unit_cell_to_lines(cell: 'ArrayLike') → tuple[ArrayLike, ArrayLike, ArrayLike]
Convert lattice vectors to plot lines.
Args:
cell
(np.array): Lattice vectors.
Returns: tuple[np.array, np.array, np.array]:
- Lines
- z-indices that sort plot elements into out-of-plane layers
- lines used to plot the unit cell
function plot_structure_2d
plot_structure_2d(
struct: 'Structure',
ax: 'Axes | None' = None,
rotation: 'str' = '10x,10y,0z',
atomic_radii: 'float | dict[str, float] | None' = None,
colors: 'dict[str, str | list[float]] | None' = None,
scale: 'float' = 1,
show_unit_cell: 'bool' = True,
show_bonds: 'bool | NearNeighbors' = False,
site_labels: "bool | Literal['symbol', 'species'] | dict[str, str | float] | Sequence[str | float]" = True,
site_labels_bbox: 'dict[str, Any] | None' = None,
label_kwargs: 'dict[str, Any] | None' = None,
bond_kwargs: 'dict[str, Any] | None' = None,
standardize_struct: 'bool | None' = None,
axis: 'bool | str' = 'off'
) → Axes
Plot pymatgen structures in 2d with matplotlib.
Inspired by ASE’s ase.visualize.plot.plot_atoms() (https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib). pymatviz aims to give similar output to ASE but supports disordered structures and avoids the conversion hassle of AseAtomsAdaptor.get_atoms(pmg_struct).
For example, these two snippets should give very similar output:
from pymatgen.ext.matproj import MPRester
mp_19017 = MPRester().get_structure_by_material_id("mp-19017")
# ASE
from ase.visualize.plot import plot_atoms
from pymatgen.io.ase import AseAtomsAdaptor
plot_atoms(AseAtomsAdaptor().get_atoms(mp_19017), rotation="10x,10y,0z", radii=0.5)
# pymatviz
from pymatviz import plot_structure_2d
plot_structure_2d(mp_19017)
Multiple structures in single figure example:
import matplotlib.pyplot as plt
from pymatgen.ext.matproj import MPRester
from pymatviz import plot_structure_2d
structures = [
    MPRester().get_structure_by_material_id(f"mp-{idx}") for idx in range(1, 5)
]
fig, axs = plt.subplots(2, 2, figsize=(12, 12))
for struct, ax in zip(structures, axs.flat):
    plot_structure_2d(struct, ax=ax)
Args:
struct
(Structure): Must be pymatgen instance.
ax
(plt.Axes, optional): Matplotlib axes on which to plot. Defaults to None.
rotation
(str, optional): Euler angles in degrees in the form '10x,20y,30z' describing angle at which to view structure. Defaults to "10x,10y,0z".
atomic_radii
(float | dict[str, float], optional): Either a scaling factor for default radii or map from element symbol to atomic radii. Defaults to covalent radii.
colors
(dict[str, str | list[float]], optional): Map from element symbols to colors, either a named color (str) or rgb(a) values like (0.2, 0.3, 0.6). Defaults to JMol colors (https://jmol.sourceforge.net/jscolors).
scale
(float, optional): Scaling of the plotted atoms and lines. Defaults to 1.
show_unit_cell
(bool, optional): Whether to draw unit cell. Defaults to True.
show_bonds
(bool | NearNeighbors, optional): Whether to draw bonds. If True, use pymatgen.analysis.local_env.CrystalNN to infer the structure’s connectivity. If False, don’t draw bonds. If a subclass of pymatgen.analysis.local_env.NearNeighbors, use that to determine connectivity. Options include VoronoiNN, MinimumDistanceNN, OpenBabelNN, CovalentBondNN, etc. Defaults to False.
site_labels
(bool | "symbol" | "species" | dict[str, str | float] | Sequence): How to annotate lattice sites. If True, labels are element species (symbol + oxidation state). If a dict, should map species strings (or element symbols but looks for species string first) to labels. If a list, must be same length as the number of sites in the crystal. If a string, must be "symbol" or "species". "symbol" hides the oxidation state, "species" shows it (equivalent to True). Defaults to True.
site_labels_bbox
(dict, optional): Keyword arguments for matplotlib.text.Text bbox like {"facecolor": "white", "alpha": 0.5}. Defaults to None.
label_kwargs
(dict, optional): Keyword arguments for matplotlib.text.Text like {"fontsize": 14}. Defaults to None.
bond_kwargs
(dict, optional): Keyword arguments for the matplotlib.path.Path class used to draw chemical bonds. Allowed are edgecolor, facecolor, color, linewidth, linestyle, antialiased, hatch, fill, capstyle, joinstyle. Defaults to None.
standardize_struct
(bool, optional): Whether to standardize the structure using SpacegroupAnalyzer(struct).get_conventional_standard_structure() before plotting. Defaults to False unless any fractional coordinates are negative, i.e. any crystal sites are outside the unit cell. Set this to False to disable this behavior which speeds up plotting for many structures.
axis
(bool | str, optional): Whether/how to show plot axes. Defaults to "off". See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.axis for details.
Raises:
ValueError
: On invalid site_labels.
Returns:
plt.Axes
: matplotlib Axes instance with plotted structure.
class ExperimentalWarning
Used for experimental show_bonds feature.
module sunburst
Hierarchical multi-level pie charts (i.e. sunbursts).
E.g. for crystal symmetry distributions.
Global Variables
- TYPE_CHECKING
function spacegroup_sunburst
spacegroup_sunburst(
data: 'Sequence[int | str] | Series',
show_counts: "Literal['value', 'percent', False]" = False,
**kwargs: 'Any'
) → Figure
Generate a sunburst plot with crystal systems as the inner ring for a list of international space group numbers.
Hint: To hide very small labels, set a uniformtext minsize and mode="hide":
fig.update_layout(uniformtext=dict(minsize=9, mode="hide"))
Args:
data
(list[int] | pd.Series): A sequence (list, tuple, pd.Series) of space group strings or numbers (from 1 to 230) or pymatgen structures.
show_counts
("value" | "percent" | False): Whether to display values below each label on the sunburst.
color_discrete_sequence
(list[str]): A list of 7 colors, one for each crystal system. Defaults to plotly.express.colors.qualitative.G10.
**kwargs
: Additional keyword arguments passed to plotly.express.sunburst.
Returns:
Figure
: The Plotly figure.
module uncertainty
Visualizations for assessing the quality of model uncertainty estimates.
Global Variables
- TYPE_CHECKING
function qq_gaussian
qq_gaussian(
y_true: 'ArrayLike | str',
y_pred: 'ArrayLike | str',
y_std: 'ArrayLike | dict[str, ArrayLike] | str | Sequence[str]',
df: 'DataFrame | None' = None,
ax: 'Axes | None' = None,
identity_line: 'bool | dict[str, Any]' = True
) → Axes
Plot the Gaussian quantile-quantile (Q-Q) plot of one (passed as array) or multiple (passed as dict) sets of uncertainty estimates for a single pair of ground truth targets y_true and model predictions y_pred.
Overconfidence relative to a Gaussian distribution is visualized as shaded areas below the parity line, underconfidence (oversized uncertainties) as shaded areas above the parity line.
The measure of calibration is how well the uncertainty percentiles conform to those of a normal distribution.
Inspired by https://git.io/JufOz. Info on Q-Q plots: https://wikipedia.org/wiki/Q-Q_plot
Args:
y_true
(array | str): Ground truth targets.
y_pred
(array | str): Model predictions.
y_std
(array | dict[str, array] | str | list[str]): Model uncertainties either as array(s) (single or dict with labels if you have multiple sources of uncertainty) or column names in df.
df
(pd.DataFrame, optional): DataFrame with y_true, y_pred and y_std columns.
ax
(Axes): matplotlib Axes on which to plot. Defaults to None.
identity_line
(bool | dict[str, Any], optional): Whether to add a parity line (y = x). Defaults to True. Pass a dict to customize line properties.
Returns:
plt.Axes
: matplotlib Axes object
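One common way to build such a calibration curve is to z-score the residuals by the predicted uncertainties and, for each expected proportion, measure the fraction of z-scores that fall inside the matching central interval of a standard normal. The sketch below follows that recipe using only the standard library. It is an illustration under that assumption, not necessarily what qq_gaussian does internally, and observed_proportions is a hypothetical name.

```python
from statistics import NormalDist


def observed_proportions(y_true, y_pred, y_std, expected):
    """Fraction of z-scored residuals inside each central Gaussian interval.

    expected holds proportions strictly between 0 and 1.
    """
    z_scores = [(t - p) / s for t, p, s in zip(y_true, y_pred, y_std)]
    std_normal = NormalDist()
    observed = []
    for prop in expected:
        # Half-width of the central interval holding `prop` of a standard normal.
        bound = std_normal.inv_cdf(0.5 + prop / 2)
        n_within = sum(abs(z) <= bound for z in z_scores)
        observed.append(n_within / len(z_scores))
    return observed


obs = observed_proportions(
    y_true=[0, 0, 0, 0],
    y_pred=[0.1, 0.5, 1.0, 3.0],
    y_std=[1, 1, 1, 1],
    expected=[0.5, 0.95],
)
```

Here obs is [0.5, 0.75]: only 75% of residuals fall inside the interval expected to hold 95%, which is the overconfident case that appears below the parity line.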
function get_err_decay
get_err_decay(
y_true: 'ArrayLike',
y_pred: 'ArrayLike',
n_rand: 'int' = 100
) → tuple[ArrayLike, ArrayLike]
Calculate the model’s error curve as samples are excluded from the calculation based on their absolute error.
Use in combination with get_std_decay to see what the error drop curve would look like if model error and uncertainty were perfectly rank-correlated.
Args:
y_true
(array): Ground truth targets.
y_pred
(array): Model predictions.
n_rand
(int, optional): Number of randomly ordered sample exclusions over which to average to estimate dummy performance. Defaults to 100.
Returns:
Tuple[array, array]
: Drop off in errors as data points are dropped based on absolute error and randomly, respectively.
function get_std_decay
get_std_decay(
y_true: 'ArrayLike',
y_pred: 'ArrayLike',
y_std: 'ArrayLike'
) → ArrayLike
Calculate the drop in model error as samples are excluded from the calculation based on the model’s uncertainty.
For models able to estimate their own uncertainty well, meaning predictions of larger error are associated with larger uncertainty, the error curve should fall off sharply at first as the highest-error points are discarded and slowly towards the end where only small-error samples with little uncertainty remain.
Note that even perfect model uncertainties would not mean this error drop curve coincides exactly with the one returned by get_err_decay. In some cases the model may have made an accurate prediction purely by chance, in which case the error is small yet a good uncertainty estimate would still be large. The same sample is then excluded at different x-axis locations, so the get_std_decay curve lies higher.
Args:
y_true
(array): Ground truth targets.
y_pred
(array): Model predictions.
y_std
(array): model’s predicted uncertainties
Returns:
array
: Error decay as data points are excluded by order of largest to smallest model uncertainties.
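The two decay curves can be sketched as running means of the absolute error, ordered either by the error itself or by the predicted uncertainty. The following is a simplified pure-Python illustration that leaves out the random-exclusion baseline of get_err_decay. The function names and the convention that entry n - 1 holds the mean error of the n samples kept are assumptions of this sketch.

```python
from itertools import accumulate


def err_decay_sketch(y_true, y_pred):
    """Mean absolute error of the n smallest-error samples, for n = 1..N."""
    abs_err = sorted(abs(t - p) for t, p in zip(y_true, y_pred))
    return [total / n for n, total in enumerate(accumulate(abs_err), start=1)]


def std_decay_sketch(y_true, y_pred, y_std):
    """Mean absolute error of the n least-uncertain samples, for n = 1..N."""
    abs_err = [abs(t - p) for t, p in zip(y_true, y_pred)]
    # Keep samples in order of increasing predicted uncertainty.
    order = sorted(range(len(abs_err)), key=lambda idx: y_std[idx])
    running = accumulate(abs_err[idx] for idx in order)
    return [total / n for n, total in enumerate(running, start=1)]


y_true, y_pred = [0, 0, 0, 0], [1, 2, 3, 4]
by_err = err_decay_sketch(y_true, y_pred)
by_std = std_decay_sketch(y_true, y_pred, y_std=[0.4, 0.1, 0.3, 0.2])
```

Both curves end at the same overall mean absolute error (2.5 in this toy data). Because the uncertainties are poorly rank-correlated with the errors, by_std never drops below by_err.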
function error_decay_with_uncert
error_decay_with_uncert(
y_true: 'ArrayLike | str',
y_pred: 'ArrayLike | str',
y_std: 'ArrayLike | dict[str, ArrayLike] | str | Sequence[str]',
df: 'DataFrame | None' = None,
n_rand: 'int' = 100,
percentiles: 'bool' = True,
ax: 'Axes | None' = None
) → Axes
Plot for assessing the quality of uncertainty estimates. If a model’s uncertainty is well calibrated, i.e. strongly correlated with its error, removing the most uncertain predictions should make the mean error decay similarly to how it decays when removing the predictions of largest error.
Args:
y_true
(array | str): Ground truth regression targets.
y_pred
(array | str): Model predictions.
y_std
(array | dict[str, ArrayLike] | str | list[str]): Model uncertainties. Can be single or multiple uncertainties (e.g. aleatoric/epistemic/total uncertainty) as dict.
n_rand
(int, optional): Number of shuffles from which to compute std.dev. of error decay by random ordering. Defaults to 100.
df
(pd.DataFrame, optional): DataFrame with y_true, y_pred and y_std columns.
percentiles
(bool, optional): Whether the x-axis shows percentiles or number of remaining samples in the MAE calculation. Defaults to True.
ax
(Axes): matplotlib Axes on which to plot. Defaults to None.
Note: If you’re not happy with the default y_max of 1.1 * rand_mean, where rand_mean is the mean of random sample exclusion, use ax.set(ylim=[None, some_value * ax.get_ylim()[1]]).
Returns:
plt.Axes
: matplotlib Axes object with plotted model error drop curve based on excluding data points by order of large to small model uncertainties.
module utils
pymatviz utility functions.
Global Variables
- TYPE_CHECKING
- PKG_DIR
- ROOT
- TEST_FILES
- VALID_BACKENDS
- mpl_key
- plotly_key
- elements_csv
- ELEM_CLASS_COLORS
- missing_covalent_radius
- atomic_numbers
- element_symbols
- Z
- symbol
function pretty_label
pretty_label(key: 'str', backend: 'Backend') → str
Map metric keys to their pretty labels.
function crystal_sys_from_spg_num
crystal_sys_from_spg_num(spg: 'float') → CrystalSystem
Get the crystal system for an international space group number.
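International space group numbers map onto the seven crystal systems in fixed, contiguous ranges, so the lookup reduces to a range check. Below is a minimal sketch of that mapping. It returns plain strings and raises ValueError for out-of-range input, both of which are assumptions of this sketch rather than documented behavior of crystal_sys_from_spg_num.

```python
# Upper bound (inclusive) of the space group range for each crystal system.
CRYSTAL_SYSTEM_UPPER_BOUNDS = [
    (2, "triclinic"),
    (15, "monoclinic"),
    (74, "orthorhombic"),
    (142, "tetragonal"),
    (167, "trigonal"),
    (194, "hexagonal"),
    (230, "cubic"),
]


def crystal_system_sketch(spg: int) -> str:
    """Return the crystal system for a space group number in 1..230."""
    if not 1 <= spg <= 230:
        raise ValueError(f"invalid space group number {spg}, must be 1 to 230")
    for upper, name in CRYSTAL_SYSTEM_UPPER_BOUNDS:
        if spg <= upper:
            return name
    raise AssertionError("unreachable")  # the range check above guarantees a match
```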
function df_to_arrays
df_to_arrays(
df: 'DataFrame | None',
*args: 'str | Sequence[str] | Sequence[ArrayLike]',
strict: 'bool' = True
) → list[ArrayLike | dict[str, ArrayLike]]
If df is None, this is a no-op: args are returned as-is. If df is a dataframe, all following args are used as column names and the column data returned as arrays (after dropping rows with NaNs in any column).
Args:
df
(pd.DataFrame | None): Optional pandas DataFrame.
*args
(list[ArrayLike | str]): Arbitrary number of arrays or column names in df.
strict
(bool, optional): If True, raise TypeError if df is not pd.DataFrame or None. If False, return args as-is. Defaults to True.
Raises:
ValueError
: If df is not None and any of the args is not a df column name.
TypeError
: If df is not pd.DataFrame and not None.
Returns:
list[ArrayLike | dict[str, ArrayLike]]
: Array data for each column name or dictionary of column names and array data.
function bin_df_cols
bin_df_cols(
df_in: 'DataFrame',
bin_by_cols: 'Sequence[str]',
group_by_cols: 'Sequence[str]' = (),
n_bins: 'int | Sequence[int]' = 100,
bin_counts_col: 'str' = 'bin_counts',
kde_col: 'str' = '',
verbose: 'bool' = True
) → DataFrame
Bin columns of a DataFrame.
Args:
df_in
(pd.DataFrame): Input dataframe to bin.
bin_by_cols
(Sequence[str]): Columns to bin.
group_by_cols
(Sequence[str]): Additional columns to group by. Defaults to ().
n_bins
(int): Number of bins to use. Defaults to 100.
bin_counts_col
(str): Column name for bin counts. Defaults to "bin_counts".
kde_col
(str): Column name for KDE bin counts e.g. 'kde_bin_counts'. Defaults to "" which means no KDE to speed things up.
verbose
(bool): If True, report df length reduction. Defaults to True.
Returns:
pd.DataFrame
: Binned DataFrame.
function patch_dict
patch_dict(
dct: 'dict[Any, Any]',
*args: 'Any',
**kwargs: 'Any'
) → Generator[dict[Any, Any], None, None]
Context manager to temporarily patch the specified keys in a dictionary and restore it to its original state on context exit.
Useful e.g. for temporary plotly fig.layout mutations:
with patch_dict(fig.layout, showlegend=False):
    fig.write_image("plot.pdf")
Args:
dct
(dict): The dictionary to be patched.
*args
: Only first element is read if present. A single dictionary containing the key-value pairs to patch.
**kwargs
: The key-value pairs to patch, provided as keyword arguments.
Yields:
dict
: The patched dictionary incl. temporary updates.
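The patch-and-restore behavior described above can be approximated with contextlib.contextmanager. This is a minimal sketch for plain dicts, not the library's source. patch_dict_sketch and the sentinel used to track keys that did not exist before patching are choices made for this illustration.

```python
from contextlib import contextmanager

_MISSING = object()  # marks keys that were absent before patching


@contextmanager
def patch_dict_sketch(dct, *args, **kwargs):
    """Temporarily update dct, then restore the original values on exit."""
    patch = {**(args[0] if args else {}), **kwargs}
    saved = {key: dct.get(key, _MISSING) for key in patch}
    dct.update(patch)
    try:
        yield dct
    finally:
        for key, old_value in saved.items():
            if old_value is _MISSING:
                dct.pop(key, None)  # key was added by the patch, so remove it
            else:
                dct[key] = old_value


layout = {"showlegend": True}
with patch_dict_sketch(layout, showlegend=False, title="tmp") as patched:
    inside = dict(patched)
```

Inside the with block the dict holds the patched values. On exit, showlegend is restored to True and the temporary title key is removed, also when the block raises.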
function luminance
luminance(color: 'tuple[float, float, float]') → float
Compute the luminance of a color as in https://stackoverflow.com/a/596243.
Args:
color
(tuple[float, float, float]): RGB color tuple with values in [0, 1].
Returns:
float
: Luminance of the color.
function pick_bw_for_contrast
pick_bw_for_contrast(
color: 'tuple[float, float, float]',
text_color_threshold: 'float' = 0.7
) → str
Choose black or white text color for a given background color based on luminance.
Args:
color
(tuple[float, float, float]): RGB color tuple with values in [0, 1].
text_color_threshold
(float, optional): Luminance threshold for choosing black or white text color. Defaults to 0.7.
Returns:
str
: “black” or “white” depending on the luminance of the background color.
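As a rough illustration of how luminance and pick_bw_for_contrast fit together, the sketch below uses the perceived-luminance weights (0.299, 0.587, 0.114) given as one option in the linked Stack Overflow answer. Which weights the library uses, and whether the comparison is strict, are assumptions here. The _sketch names are hypothetical.

```python
def luminance_sketch(color):
    """Weighted sum of RGB channels, each in [0, 1]."""
    red, green, blue = color
    return 0.299 * red + 0.587 * green + 0.114 * blue


def pick_bw_sketch(color, text_color_threshold=0.7):
    """Black text on light backgrounds, white text on dark ones."""
    is_light = luminance_sketch(color) > text_color_threshold
    return "black" if is_light else "white"


yellow_text = pick_bw_sketch((1.0, 1.0, 0.0))  # bright background
blue_text = pick_bw_sketch((0.0, 0.0, 1.0))  # dark background
```

Yellow has a luminance of about 0.886 and gets black text, while pure blue (0.114) gets white text.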
function si_fmt
si_fmt(
val: 'float',
fmt: 'str' = '.1f',
sep: 'str' = '',
binary: 'bool' = False,
decimal_threshold: 'float' = 0.01
) → str
Convert large numbers into human readable format using SI prefixes.
Supports binary (1024) and metric (1000) mode.
https://nist.gov/pml/weights-and-measures/metric-si-prefixes
Args:
val
(int | float): Some numerical value to format.
binary
(bool, optional): If True, scaling factor is 2^10 = 1024 else 1000. Defaults to False.
fmt
(str): f-string format specifier. Configure precision and left/right padding in returned string. Defaults to ".1f". Can be used to ensure leading or trailing whitespace for shorter numbers. See https://docs.python.org/3/library/string.html#format-specification-mini-language.
sep
(str): Separator between number and postfix. Defaults to "".
decimal_threshold
(float): abs(value) below 1 but above this threshold will be left as decimals. Only below this threshold is a greek suffix added (milli, micro, etc.). Defaults to 0.01, i.e. 0.01 -> "0.01" while 0.0099 -> "9.9m". Setting decimal_threshold=0.1 would format 0.01 as "10m" and leave 0.1 as is.
Returns:
str
: Formatted number.
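The scaling logic described by these parameters can be sketched as repeated division (or multiplication) by the scaling factor until the value lands in a readable range. This is an illustrative approximation only. The prefix letters, the handling of values beyond the largest prefix, and the reuse of the same factor for sub-unit values are assumptions of this sketch, and si_fmt_sketch is a hypothetical name.

```python
def si_fmt_sketch(val, fmt=".1f", sep="", binary=False, decimal_threshold=0.01):
    """Format val with an SI-style suffix, e.g. 12345 -> '12.3K'."""
    factor = 1024 if binary else 1000
    large = ["K", "M", "G", "T", "P", "E", "Z", "Y"]  # assumed prefix letters
    small = ["m", "μ", "n", "p", "f", "a", "z", "y"]
    scaled, suffix, steps = float(val), "", 0
    if abs(scaled) >= factor:
        # Scale down until below one factor or the prefixes run out.
        while abs(scaled) >= factor and steps < len(large):
            scaled /= factor
            steps += 1
        suffix = large[steps - 1]
    elif 0 < abs(scaled) < decimal_threshold:
        # Scale up small values; values above the threshold stay as decimals.
        while abs(scaled) < 1 and steps < len(small):
            scaled *= factor
            steps += 1
        suffix = small[steps - 1]
    return f"{scaled:{fmt}}{sep}{suffix}"


examples = [
    si_fmt_sketch(12345),
    si_fmt_sketch(1_500_000, sep=" "),
    si_fmt_sketch(2048, binary=True),
    si_fmt_sketch(0.0099),
]
```

With these inputs, examples is ["12.3K", "1.5 M", "2.0K", "9.9m"]. Values that round up to the next factor (e.g. 999999 with ".1f") are not renormalized in this sketch.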
function styled_html_tag
styled_html_tag(text: 'str', tag: 'str' = 'span', style: 'str' = '') → str
Wrap text in an HTML tag (span by default) with custom style.
Style defaults to decreased font size and weight e.g. to display units in plotly labels and annotations.
Args:
text
(str): Text to wrap in span.
tag
(str, optional): HTML tag name. Defaults to "span".
style
(str, optional): CSS style string. Defaults to "font-size: 0.8em; font-weight: lighter;".
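A helper of this kind amounts to string interpolation. The sketch below assumes an empty style falls back to the documented default and that the style attribute is single-quoted. The exact markup the library emits may differ, and styled_html_tag_sketch is a hypothetical name.

```python
DEFAULT_STYLE = "font-size: 0.8em; font-weight: lighter;"


def styled_html_tag_sketch(text, tag="span", style=""):
    """Wrap text in an HTML tag carrying an inline CSS style."""
    style = style or DEFAULT_STYLE  # empty string falls back to the default
    return f"<{tag} style='{style}'>{text}</{tag}>"


label = "Energy " + styled_html_tag_sketch("(eV)")
```

Such a string can be used wherever HTML is rendered in text, for example to show units in a smaller, lighter font in plotly axis titles.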
function validate_fig
validate_fig(func: 'Callable[P, R]') → Callable[P, R]
Decorator to validate the type of fig keyword argument in a function.
function annotate
annotate(
text: 'str',
fig: 'AxOrFig | None' = None,
color: 'str' = 'black',
**kwargs: 'Any'
) → AxOrFig
Annotate a matplotlib or plotly figure.
Args:
text
(str): The text to use for annotation.
fig
(plt.Axes | plt.Figure | go.Figure | None, optional): The matplotlib Axes, Figure or plotly Figure to annotate. If None, the current matplotlib Axes will be used. Defaults to None.
color
(str, optional): The color of the text. Defaults to "black".
**kwargs
: Additional arguments to pass to matplotlib’s AnchoredText or plotly’s fig.add_annotation().
Returns:
plt.Axes | plt.Figure | go.Figure
: The annotated figure.
function get_fig_xy_range
get_fig_xy_range(
fig: 'Figure | Figure | Axes',
trace_idx: 'int' = 0
) → tuple[tuple[float, float], tuple[float, float]]
Get the x and y range of a plotly or matplotlib figure.
Args:
fig
(go.Figure | plt.Figure | plt.Axes): plotly/matplotlib figure or axes.
trace_idx
(int, optional): Index of the trace to use for measuring x/y limits. Defaults to 0. Unused if kaleido package is installed and the figure’s actual x/y-range can be obtained from fig.full_figure_for_development().
Returns:
tuple[tuple[float, float], tuple[float, float]]
: The x and y range of the figure in the format ((x_min, x_max), (y_min, y_max)).