API

module correlation

Plot distributions of correlation matrix eigenvalues.

Global Variables


function marchenko_pastur_pdf

marchenko_pastur_pdf(x: 'float', gamma: 'float', sigma: 'float' = 1) → float

Evaluate the Marchenko-Pastur probability density function at x. The Marchenko-Pastur distribution describes the density of singular values of large rectangular random matrices.

See https://wikipedia.org/wiki/Marchenko-Pastur_distribution.

By comparing the eigenvalue distribution of a correlation matrix to this PDF, one can gauge the significance of correlations.
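
A minimal pure-Python sketch of the density being compared against, assuming the textbook parametrization from the Wikipedia article: gamma is the aspect ratio p/n of a data matrix with n samples and p variables (0 < gamma ≤ 1) and sigma² is the variance of its entries. The library's own convention for gamma may differ, and marchenko_pastur_pdf_sketch is a hypothetical name.

```python
import math

# Hypothetical sketch, not pymatviz code. Assumes the Wikipedia parametrization:
# gamma = p / n (variables / samples, 0 < gamma <= 1), sigma**2 = noise variance.


def marchenko_pastur_pdf_sketch(x: float, gamma: float, sigma: float = 1.0) -> float:
    """Marchenko-Pastur density at x."""
    # Edges of the bulk: sigma^2 * (1 -/+ sqrt(gamma))^2
    lam_minus = sigma**2 * (1 - math.sqrt(gamma)) ** 2
    lam_plus = sigma**2 * (1 + math.sqrt(gamma)) ** 2
    if not lam_minus < x < lam_plus:
        return 0.0  # pure noise puts no density outside the bulk
    root = math.sqrt((lam_plus - x) * (x - lam_minus))
    return root / (2 * math.pi * sigma**2 * gamma * x)


# Sanity check: integrating the density over the bulk (midpoint rule) gives ~1
gamma = 0.5
lo, hi = (1 - math.sqrt(gamma)) ** 2, (1 + math.sqrt(gamma)) ** 2
n_steps = 100_000
step = (hi - lo) / n_steps
area = step * sum(
    marchenko_pastur_pdf_sketch(lo + (idx + 0.5) * step, gamma) for idx in range(n_steps)
)
print(f"integral over [{lo:.3f}, {hi:.3f}]: {area:.4f}")
```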

Args:

Returns:


function marchenko_pastur

marchenko_pastur(
    matrix: 'ArrayLike',
    gamma: 'float',
    sigma: 'float' = 1,
    filter_high_evals: 'bool' = False,
    ax: 'Axes | None' = None
) → Axes

Plot the eigenvalue distribution of a symmetric matrix (usually a correlation matrix) against the Marchenko-Pastur distribution.

The probability of a random matrix having eigenvalues >= (1 + sqrt(gamma))^2 in the absence of any signal is vanishingly small. Thus, if eigenvalues larger than that appear, they correspond to statistically significant signals.
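
The threshold above is the upper edge of the Marchenko-Pastur bulk for unit noise variance. A small sketch of the resulting filter, with the same assumed gamma convention and hypothetical helper names; in practice the eigenvalues could come from numpy.linalg.eigvalsh applied to the correlation matrix.

```python
import math
from collections.abc import Iterable

# Hypothetical helpers, not part of pymatviz.


def mp_upper_edge(gamma: float, sigma: float = 1.0) -> float:
    """Largest eigenvalue expected from pure noise: sigma^2 * (1 + sqrt(gamma))^2."""
    return sigma**2 * (1 + math.sqrt(gamma)) ** 2


def significant_eigenvalues(
    evals: Iterable[float], gamma: float, sigma: float = 1.0
) -> list[float]:
    """Return eigenvalues at or above the noise edge, largest first."""
    edge = mp_upper_edge(gamma, sigma)
    return sorted((ev for ev in evals if ev >= edge), reverse=True)


# gamma = 0.25 puts the noise edge at (1 + 0.5)^2 = 2.25
print(significant_eigenvalues([0.3, 1.0, 2.0, 2.25, 7.5], gamma=0.25))  # [7.5, 2.25]
```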

Args:

Returns:

module cumulative

Plot the cumulative distribution of residuals and absolute errors.

Global Variables


function cumulative_residual

cumulative_residual(
    res: 'ArrayLike',
    ax: 'Axes | None' = None,
    **kwargs: 'Any'
) → Axes

Plot the empirical cumulative distribution for the residuals (y - mu).

Args:

Returns:


function cumulative_error

cumulative_error(
    abs_err: 'ArrayLike',
    ax: 'Axes | None' = None,
    **kwargs: 'Any'
) → Axes

Plot the empirical cumulative distribution of the absolute errors abs(y_true - y_pred).
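
Both functions plot an empirical cumulative distribution function (ECDF): for each observed value, the fraction of values less than or equal to it. A dependency-free sketch of that computation on absolute errors (ecdf and fraction_at_most are hypothetical helpers, not part of this module):

```python
from bisect import bisect_right
from collections.abc import Sequence

# Hypothetical helpers illustrating the ECDF, not pymatviz code.


def ecdf(values: Sequence[float]) -> tuple[list[float], list[float]]:
    """Sorted values and the fraction of values <= each of them."""
    xs = sorted(values)
    n_vals = len(xs)
    return xs, [(idx + 1) / n_vals for idx in range(n_vals)]


def fraction_at_most(values: Sequence[float], threshold: float) -> float:
    """ECDF evaluated at an arbitrary threshold."""
    return bisect_right(sorted(values), threshold) / len(values)


y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 2.0, 2.0, 4.5]
abs_err = [abs(true - pred) for true, pred in zip(y_true, y_pred)]  # [0.5, 0.0, 1.0, 0.5]
print(ecdf(abs_err))  # ([0.0, 0.5, 0.5, 1.0], [0.25, 0.5, 0.75, 1.0])
print(fraction_at_most(abs_err, 0.5))  # 0.75
```

For cumulative_residual, the same ecdf applies to the signed residuals y - mu instead of their absolute values.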

Args:

Returns:

module histograms

Histograms and bar charts.

Global Variables


function true_pred_hist

true_pred_hist(
    y_true: 'ArrayLike | str',
    y_pred: 'ArrayLike | str',
    y_std: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None,
    cmap: 'str' = 'hot',
    truth_color: 'str' = 'blue',
    true_label: 'str' = '$y_\mathrm{true}$',
    pred_label: 'str' = '$y_\mathrm{pred}$',
    **kwargs: 'Any'
) → Axes

Plot a histogram of model predictions with bars colored by the mean uncertainty of the predictions in each bin, overlaid with a more transparent histogram of the ground truth values.
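
To make the coloring concrete, here is a sketch of binning predictions into equal-width bins and averaging the uncertainties that land in each bin. The binning scheme and the name mean_std_per_bin are assumptions, not the function's actual internals.

```python
import math
from collections.abc import Sequence

# Hypothetical sketch of "mean uncertainty per histogram bin", not pymatviz code.


def mean_std_per_bin(
    y_pred: Sequence[float], y_std: Sequence[float], bins: int = 10
) -> tuple[list[int], list[float]]:
    """Count predictions per equal-width bin and average their uncertainties."""
    lo, hi = min(y_pred), max(y_pred)
    width = (hi - lo) / bins or 1.0  # avoid division by zero if all predictions are equal
    counts = [0] * bins
    std_sums = [0.0] * bins
    for pred, std in zip(y_pred, y_std):
        idx = min(int((pred - lo) / width), bins - 1)  # right edge goes into the last bin
        counts[idx] += 1
        std_sums[idx] += std
    # empty bins get NaN instead of a misleading zero uncertainty
    means = [total / count if count else math.nan for total, count in zip(std_sums, counts)]
    return counts, means


counts, means = mean_std_per_bin([0, 1, 2, 3, 4], [0.1, 0.1, 0.2, 0.4, 0.4], bins=2)
print(counts, [round(mean, 3) for mean in means])  # [2, 3] [0.1, 0.333]
```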

Args:

Returns:


function spacegroup_hist

spacegroup_hist(
    data: 'Sequence[int | str | Structure] | Series',
    show_counts: 'bool' = True,
    xticks: "Literal['all', 'crys_sys_edges'] | int" = 20,
    show_empty_bins: 'bool' = False,
    ax: 'Axes | None' = None,
    backend: 'Backend' = 'plotly',
    text_kwargs: 'dict[str, Any] | None' = None,
    **kwargs: 'Any'
) → Axes | Figure

Plot a histogram of space groups shaded by crystal system.
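
The shading relies on the standard partition of the 230 international space group numbers into the seven crystal systems. A sketch of that lookup (crystal_system is a hypothetical helper name):

```python
# Hypothetical helper, not pymatviz code. Upper bound (inclusive) of each
# crystal system's range of international space group numbers.
CRYSTAL_SYS_UPPER = (
    (2, "triclinic"),
    (15, "monoclinic"),
    (74, "orthorhombic"),
    (142, "tetragonal"),
    (167, "trigonal"),
    (194, "hexagonal"),
    (230, "cubic"),
)


def crystal_system(spg_num: int) -> str:
    """Map an international space group number (1-230) to its crystal system."""
    if not 1 <= spg_num <= 230:
        raise ValueError(f"Invalid space group number {spg_num}, must be in 1-230")
    for upper, name in CRYSTAL_SYS_UPPER:
        if spg_num <= upper:
            return name
    raise AssertionError("unreachable: ranges cover 1-230")


print([crystal_system(num) for num in (1, 14, 62, 139, 166, 194, 225)])
# ['triclinic', 'monoclinic', 'orthorhombic', 'tetragonal', 'trigonal', 'hexagonal', 'cubic']
```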

Args:

Returns:


function elements_hist

elements_hist(
    formulas: 'ElemValues',
    count_mode: 'CountMode' = 'composition',
    log: 'bool' = False,
    keep_top: 'int | None' = None,
    ax: 'Axes | None' = None,
    bar_values: "Literal['percent', 'count'] | None" = 'percent',
    h_offset: 'int' = 0,
    v_offset: 'int' = 10,
    rotation: 'int' = 45,
    **kwargs: 'Any'
) → Axes

Plot a histogram of elements (e.g. to show occurrence in a dataset).

Adapted from https://github.com/kaaiian/ML_figures (https://git.io/JmbaI).

Args:

Returns:

module io

I/O utilities for saving figures and dataframes to various image formats.

Global Variables


function save_fig

save_fig(
    fig: 'Figure | Figure | Axes',
    path: 'str',
    plotly_config: 'dict[str, Any] | None' = None,
    env_disable: 'Sequence[str]' = ('CI',),
    pdf_sleep: 'float' = 0.6,
    style: 'str' = '',
    prec: 'int | None' = None,
    template: 'str | None' = None,
    **kwargs: 'Any'
) → None

Write a plotly or matplotlib figure to disk (as HTML/PDF/SVG/…).

If the file has a .svelte extension, insert {...$$props} into the figure’s top-level div so it can later be styled and customized from Svelte code.
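
A sketch of that insertion as plain string manipulation, assuming the figure's top-level div is the first <div tag in the exported HTML. add_svelte_props is hypothetical, not save_fig's actual implementation.

```python
# Hypothetical sketch, not save_fig's actual code. Assumes the figure's top-level
# element is the first "<div" in the exported HTML fragment.


def add_svelte_props(html: str) -> str:
    """Insert {...$$props} into the first <div tag so a Svelte parent can pass props."""
    tag = "<div"
    idx = html.find(tag)
    if idx == -1:
        return html  # no div to attach the props to, leave the markup unchanged
    insert_at = idx + len(tag)
    # doubled braces produce literal { } inside the f-string
    return f"{html[:insert_at]} {{...$$props}}{html[insert_at:]}"


print(add_svelte_props('<div id="plot"><div class="svg-container"></div></div>'))
# <div {...$$props} id="plot"><div class="svg-container"></div></div>
```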

Args:


function save_and_compress_svg

save_and_compress_svg(fig: 'Figure | Figure | Axes', filename: 'str') → None

Save a Plotly figure as SVG and HTML to the assets/ folder. Compress the SVG file with the svgo CLI if it is available in PATH.

Args:

Raises:


function df_to_pdf

df_to_pdf(
    styler: 'Styler',
    file_path: 'str | Path',
    crop: 'bool' = True,
    size: 'str | None' = None,
    style: 'str' = '',
    styler_css: 'bool | dict[str, str]' = True,
    **kwargs: 'Any'
) → None

Export a pandas Styler to PDF with WeasyPrint.

Args:


function normalize_and_crop_pdf

normalize_and_crop_pdf(
    file_path: 'str | Path',
    on_gs_not_found: "Literal['ignore', 'warn', 'error']" = 'warn'
) → None

Normalize a PDF using Ghostscript and then crop it. Without gs normalization, pdfCropMargins sometimes corrupts the PDF.
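
A sketch of the normalization step with the three on_gs_not_found behaviors as their names suggest (skip silently, warn, or raise). The Ghostscript flags are standard pdfwrite options, but whether the library passes exactly these, and the helper gs_normalize with its gs_cmd parameter, are assumptions.

```python
import shutil
import subprocess
import warnings
from typing import Literal

# Hypothetical sketch, not pymatviz code. The gs flags are standard Ghostscript
# options; whether the library uses exactly these is an assumption.


def gs_normalize(
    src: str,
    dst: str,
    on_gs_not_found: Literal["ignore", "warn", "error"] = "warn",
    gs_cmd: str = "gs",  # hypothetical parameter, lets callers point at another binary
) -> bool:
    """Rewrite src into dst with Ghostscript. Return True if gs ran, False if skipped."""
    gs_path = shutil.which(gs_cmd)
    if gs_path is None:
        msg = f"{gs_cmd!r} not found in PATH, skipping PDF normalization"
        if on_gs_not_found == "error":
            raise RuntimeError(msg)
        if on_gs_not_found == "warn":
            warnings.warn(msg, stacklevel=2)
        return False  # "ignore" (and "warn") fall through to skipping
    cmd = [
        gs_path,
        "-sDEVICE=pdfwrite",  # re-render the document through Ghostscript's PDF writer
        "-dNOPAUSE",
        "-dBATCH",
        "-dQUIET",
        f"-sOutputFile={dst}",  # separate output file, gs reads src while writing
        src,
    ]
    subprocess.run(cmd, check=True)
    return True
```

The cropping step would then run on dst rather than on the original file.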

Args:


function df_to_html_table

df_to_html_table(
    styler: 'Styler',
    file_path: 'str | Path | None' = None,
    inline_props: 'str | None' = '',
    script: 'str | None' = '',
    styles: 'str | None' = 'table { overflow: scroll; max-width: 100%; display: block; }\ntable {\n    scrollbar-width: none;  /* Firefox */\n}\ntable::-webkit-scrollbar {\n    display: none;  /* Safari and Chrome */\n}',
    styler_css: 'bool | dict[str, str]' = True,
    sortable: 'bool' = True,
    post_process: 'Callable[[str], str] | None' = None,
    **kwargs: 'Any'
) → str

Convert a pandas Styler to a Svelte table.

Args:

Returns:


class BufferedIOBase

Base class for buffered IO objects.

The main difference with RawIOBase is that the read() method supports omitting the size argument, and does not have a default implementation that defers to readinto().

In addition, read(), readinto() and write() may raise BlockingIOError if the underlying raw stream is in non-blocking mode and not ready; unlike their raw counterparts, they will never return None.

A typical implementation should not inherit from a RawIOBase implementation, but wrap one.


class IOBase

The abstract base class for all I/O classes.

This class provides dummy implementations for many methods that derived classes can override selectively; the default implementations represent a file that cannot be read, written or seeked.

Even though IOBase does not declare read, readinto, or write because their signatures will vary, implementations and clients should consider those methods part of the interface. Also, implementations may raise UnsupportedOperation when operations they do not support are called.

The basic type used for binary data read from or written to a file is bytes. Other bytes-like objects are accepted as method arguments too. In some cases (such as readinto), a writable object is required. Text I/O classes work with str data.

Note that calling any method (except additional calls to close(), which are ignored) on a closed stream should raise a ValueError.

IOBase (and its subclasses) support the iterator protocol, meaning that an IOBase object can be iterated over yielding the lines in a stream.

IOBase also supports the with statement. In this example, fp is closed after the suite of the with statement is complete:

with open("spam.txt", "w") as fp:
    fp.write("Spam and eggs!")


class RawIOBase

Base class for raw binary I/O.


class TextIOBase

Base class for text I/O.

This class provides a character and line based interface to stream I/O. There is no readinto method because Python’s character strings are immutable.


class UnsupportedOperation


class TqdmDownload

Progress bar for urllib.request.urlretrieve file download.

Adapted from the official TqdmUpTo example. See https://github.com/tqdm/tqdm/blob/4c956c20b83be4312460fc0c4812eeb3fef5e7df/README.rst#hooks-and-callbacks

method __init__

__init__(*args: 'Any', **kwargs: 'Any') → None

Sets default values appropriate for file downloads for unit, unit_scale, unit_divisor, miniters, desc.


property format_dict

Public API for read-only member access.


method update_to

update_to(
    n_blocks: 'int' = 1,
    block_size: 'int' = 1,
    total_size: 'int | None' = None
) → bool | None

Update hook for urlretrieve.
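
urllib.request.urlretrieve calls its reporthook with three cumulative numbers: blocks transferred so far, the block size, and the total file size (-1 if the server reports none). update_to turns them into an increment by subtracting the progress already shown, as in the tqdm TqdmUpTo recipe. A tqdm-free sketch of that arithmetic, using a hypothetical ProgressCounter class in place of the real progress bar:

```python
from typing import Optional

# Hypothetical stand-in for a tqdm bar, not pymatviz code: only tracks counters.


class ProgressCounter:
    """Tracks bytes done and the expected total, like a minimal progress bar."""

    def __init__(self) -> None:
        self.n = 0  # bytes reported so far, mirrors tqdm's n attribute
        self.total: Optional[int] = None

    def update(self, increment: int) -> None:
        self.n += increment

    def update_to(
        self, n_blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
    ) -> None:
        """reporthook for urlretrieve: turn cumulative block counts into an increment."""
        if total_size is not None and total_size > 0:  # urlretrieve passes -1 if unknown
            self.total = total_size
        done = n_blocks * block_size
        if self.total is not None:
            done = min(done, self.total)  # the last block is usually only partly filled
        self.update(done - self.n)


# Simulate the calls for a 4500-byte file read in 1024-byte blocks:
# once before the first block (block 0), then once after each block
bar = ProgressCounter()
for block_num in range(6):
    bar.update_to(block_num, 1024, 4500)
print(bar.n, bar.total)  # 4500 4500
```

With a real bar, the hook is passed as urllib.request.urlretrieve(url, filename, reporthook=bar.update_to).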

Args:

Returns:

module parity

Parity, residual and density plots.

Global Variables


function hist_density

hist_density(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    sort: 'bool' = True,
    bins: 'int' = 100,
    method: 'str' = 'nearest'
) → tuple[ArrayLike, ArrayLike, ArrayLike]

Return an approximate density of 2d points.
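
One common way to approximate point density, sketched without NumPy or SciPy: bin the points on a regular 2D grid and assign each point the count of its bin. This only illustrates the nearest-bin idea suggested by method='nearest'; hist_density_sketch is hypothetical and may differ from the library's implementation. Sorting by density (sort=True) means the densest points are drawn last, on top.

```python
from collections import Counter
from collections.abc import Sequence

# Hypothetical sketch of histogram-based point density, not pymatviz code.


def hist_density_sketch(
    xs: Sequence[float], ys: Sequence[float], bins: int = 100, sort: bool = True
) -> tuple[list[float], list[float], list[int]]:
    """Return xs, ys and the number of points sharing each point's 2D bin."""

    def bin_index(val: float, lo: float, hi: float) -> int:
        if hi == lo:
            return 0
        return min(int((val - lo) / (hi - lo) * bins), bins - 1)

    x_lo, x_hi, y_lo, y_hi = min(xs), max(xs), min(ys), max(ys)
    keys = [(bin_index(x, x_lo, x_hi), bin_index(y, y_lo, y_hi)) for x, y in zip(xs, ys)]
    counts = Counter(keys)
    points = [(x, y, counts[key]) for x, y, key in zip(xs, ys, keys)]
    if sort:
        points.sort(key=lambda point: point[2])  # densest points last so they plot on top
    out_x, out_y, density = (list(col) for col in zip(*points))
    return out_x, out_y, density


print(hist_density_sketch([0.0, 0.1, 0.05, 1.0], [0.0, 0.1, 0.05, 1.0], bins=2))
# ([1.0, 0.0, 0.1, 0.05], [1.0, 0.0, 0.1, 0.05], [1, 3, 3, 3])
```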

Args:

Returns:


function density_scatter

density_scatter(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None,
    log_density: 'bool' = True,
    hist_density_kwargs: 'dict[str, Any] | None' = None,
    color_bar: 'bool | dict[str, Any]' = True,
    xlabel: 'str | None' = None,
    ylabel: 'str | None' = None,
    identity_line: 'bool | dict[str, Any]' = True,
    best_fit_line: 'bool | dict[str, Any]' = True,
    stats: 'bool | dict[str, Any]' = True,
    **kwargs: 'Any'
) → Axes

Scatter plot colored (and optionally sorted) by density.

Args:

Returns: plt.Axes:


function scatter_with_err_bar

scatter_with_err_bar(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    xerr: 'ArrayLike | None' = None,
    yerr: 'ArrayLike | None' = None,
    ax: 'Axes | None' = None,
    identity_line: 'bool | dict[str, Any]' = True,
    best_fit_line: 'bool | dict[str, Any]' = True,
    xlabel: 'str' = 'Actual',
    ylabel: 'str' = 'Predicted',
    title: 'str | None' = None,
    **kwargs: 'Any'
) → Axes

Scatter plot with optional x- and/or y-error bars. Useful when passing model uncertainties as yerr=y_std for checking if uncertainty correlates with error, i.e. if points farther from the parity line have larger uncertainty.

Args:

Returns:


function density_hexbin

density_hexbin(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None,
    weights: 'ArrayLike | None' = None,
    identity_line: 'bool | dict[str, Any]' = True,
    best_fit_line: 'bool | dict[str, Any]' = True,
    xlabel: 'str' = 'Actual',
    ylabel: 'str' = 'Predicted',
    cbar_label: 'str | None' = 'Density',
    cbar_coords: 'tuple[float, float, float, float]' = (0.95, 0.03, 0.03, 0.7),
    **kwargs: 'Any'
) → Axes

Hexagonal-grid scatter plot colored by point density or by a third dimension passed as weights.

Args:

Returns:


function density_scatter_with_hist

density_scatter_with_hist(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    cell: 'GridSpec | None' = None,
    bins: 'int' = 100,
    ax: 'Axes | None' = None,
    **kwargs: 'Any'
) → Axes

Scatter plot colored (and optionally sorted) by density with histograms along each dimension.


function density_hexbin_with_hist

density_hexbin_with_hist(
    x: 'ArrayLike | str',
    y: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    cell: 'GridSpec | None' = None,
    bins: 'int' = 100,
    **kwargs: 'Any'
) → Axes

Hexagonal-grid scatter plot colored by density or by a third dimension passed as color_by, with histograms along each dimension.


function residual_vs_actual

residual_vs_actual(
    y_true: 'ArrayLike | str',
    y_pred: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None,
    xlabel: 'str' = 'Actual value',
    ylabel: 'str' = 'Residual ($y_\mathrm{true} - y_\mathrm{pred}$)',
    **kwargs: 'Any'
) → Axes

Plot targets on the x-axis vs residuals (y_err = y_true - y_pred) on the y-axis.

Args:

Returns:

module phonons

Plotting functions for pymatgen phonon band structures and density of states.

Global Variables


function pretty_sym_point

pretty_sym_point(symbol: 'str') → str

Convert a symbol to a pretty-printed version.


function get_band_xaxis_ticks

get_band_xaxis_ticks(
    band_struct: 'PhononBands',
    branches: 'Sequence[str] | set[str]' = ()
) → tuple[list[float], list[str]]

Get all ticks and labels for a band structure plot.

Returns:


function plot_phonon_bands

plot_phonon_bands(
    band_structs: 'PhononBands | dict[str, PhononBands]',
    line_kwargs: 'dict[str, Any] | None' = None,
    branches: 'Sequence[str]' = (),
    branch_mode: 'BranchMode' = 'union',
    shaded_ys: 'dict[tuple[YMin, YMax], dict[str, Any]] | bool | None' = None,
    **kwargs: 'Any'
) → Figure

Plot single or multiple pymatgen band structures using Plotly, focusing on the minimum set of overlapping branches.

Warning: Only tested with phonon band structures so far, but the plan is to extend support to electronic band structures.

Args: band_structs (PhononBandStructureSymmLine | dict[str, PhononBandStructure]): Single BandStructureSymmLine or PhononBandStructureSymmLine object or a dict with labels mapped to multiple such objects.

Returns:


function plot_phonon_dos

plot_phonon_dos(
    doses: 'PhononDos | dict[str, PhononDos]',
    stack: 'bool' = False,
    sigma: 'float' = 0,
    units: "Literal['THz', 'eV', 'meV', 'Ha', 'cm-1']" = 'THz',
    normalize: "Literal['max', 'sum', 'integral'] | None" = None,
    last_peak_anno: 'str | None' = None,
    **kwargs: 'Any'
) → Figure

Plot phonon DOS using Plotly.

Args:

Returns:


function convert_frequencies

convert_frequencies(
    frequencies: 'ndarray',
    unit: "Literal['THz', 'eV', 'meV', 'Ha', 'cm-1']" = 'THz'
) → ndarray

Convert frequencies from THz to the specified unit.
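
The conversion factors follow from E = hν for energies and ν/c for wavenumbers. A sketch with CODATA 2018 constants (truncated); the factor table and convert_frequencies_sketch are illustrative and not necessarily the constants the library uses:

```python
from collections.abc import Sequence

# Hypothetical sketch, not pymatviz code.
PLANCK_EV_S = 4.135667696e-15  # Planck constant in eV*s (CODATA 2018, truncated)
SPEED_OF_LIGHT_CM_S = 2.99792458e10  # speed of light in cm/s (exact)
HARTREE_EV = 27.211386245988  # Hartree energy in eV (CODATA 2018)
THZ = 1e12  # 1 THz in Hz

FACTORS_FROM_THZ = {
    "THz": 1.0,
    "eV": PLANCK_EV_S * THZ,  # E = h * nu
    "meV": PLANCK_EV_S * THZ * 1e3,
    "Ha": PLANCK_EV_S * THZ / HARTREE_EV,
    "cm-1": THZ / SPEED_OF_LIGHT_CM_S,  # wavenumber = nu / c
}


def convert_frequencies_sketch(frequencies: Sequence[float], unit: str = "THz") -> list[float]:
    """Convert frequencies given in THz to unit."""
    try:
        factor = FACTORS_FROM_THZ[unit]
    except KeyError:
        raise ValueError(f"Unsupported {unit=}, pick one of {list(FACTORS_FROM_THZ)}") from None
    return [freq * factor for freq in frequencies]


print([round(val, 3) for val in convert_frequencies_sketch([1.0, 10.0], "meV")])  # [4.136, 41.357]
```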

Args:

Returns:


function plot_phonon_bands_and_dos

plot_phonon_bands_and_dos(
    band_structs: 'PhononBands | dict[str, PhononBands]',
    doses: 'PhononDos | dict[str, PhononDos]',
    bands_kwargs: 'dict[str, Any] | None' = None,
    dos_kwargs: 'dict[str, Any] | None' = None,
    subplot_kwargs: 'dict[str, Any] | None' = None,
    all_line_kwargs: 'dict[str, Any] | None' = None,
    per_line_kwargs: 'dict[str, dict[str, Any]] | None' = None,
    **kwargs: 'Any'
) → Figure

Plot phonon DOS and band structure using Plotly.

Args:

Returns:


class PhononDBDoc

Dataclass for phonon DB docs.

method __init__

__init__(
    structure: 'Structure',
    phonon_bandstructure: 'PhononBands',
    phonon_dos: 'PhononDos',
    free_energies: 'list[float]',
    internal_energies: 'list[float]',
    heat_capacities: 'list[float]',
    entropies: 'list[float]',
    temps: 'list[float] | None' = None,
    has_imaginary_modes: 'bool | None' = None,
    primitive: 'Structure | None' = None,
    supercell: 'list[list[int]] | None' = None,
    nac_params: 'dict[str, Any] | None' = None,
    thermal_displacement_data: 'dict[str, Any] | None' = None,
    mp_id: 'str | None' = None,
    formula: 'str | None' = None
) → None

module powerups

Powerups/enhancements such as parity lines, annotations and marginals for matplotlib and plotly figures.

Global Variables


function with_marginal_hist

with_marginal_hist(
    xs: 'ArrayLike',
    ys: 'ArrayLike',
    cell: 'GridSpec | None' = None,
    bins: 'int' = 100,
    fig: 'Figure | Axes | None' = None
) → Axes

Call this before creating a plot and use the returned ax_main for all subsequent plotting operations. This creates a grid of plots with the main plot in the lower left and narrow histograms along its x- and/or y-axes displayed above it and along its right edge.

Args:

Returns:


function annotate_bars

annotate_bars(
    ax: 'Axes | None' = None,
    v_offset: 'float' = 10,
    h_offset: 'float' = 0,
    labels: 'Sequence[str | int | float] | None' = None,
    fontsize: 'int' = 14,
    y_max_headroom: 'float' = 1.2,
    adjust_test_pos: 'bool' = False,
    **kwargs: 'Any'
) → None

Annotate each bar in a bar plot with a label.

Args:


function annotate_metrics

annotate_metrics(
    xs: 'ArrayLike',
    ys: 'ArrayLike',
    fig: 'AxOrFig | None' = None,
    metrics: 'dict[str, float] | Sequence[str]' = ('MAE', 'R2'),
    prefix: 'str' = '',
    suffix: 'str' = '',
    fmt: 'str' = '.3',
    **kwargs: 'Any'
) → AnchoredText

Given x and y values of equal length and an optional Axes or figure object, print the values’ mean absolute error and R^2 coefficient of determination on it.
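
For reference, the two default metrics in plain Python, treating xs as ground truth and ys as predictions (an assumption that matters for R², which is not symmetric) and formatting with the default fmt='.3'. metrics_text is a hypothetical helper that only builds the text; it does not draw anything.

```python
from collections.abc import Sequence

# Hypothetical helper, not annotate_metrics itself. Assumes xs = truth, ys = predictions.


def metrics_text(xs: Sequence[float], ys: Sequence[float], fmt: str = ".3") -> str:
    """MAE and R^2 of ys against xs, one metric per line."""
    if len(xs) != len(ys):
        raise ValueError(f"xs and ys must have equal length, got {len(xs)} and {len(ys)}")
    mae = sum(abs(x - y) for x, y in zip(xs, ys)) / len(xs)
    mean_x = sum(xs) / len(xs)
    ss_res = sum((x - y) ** 2 for x, y in zip(xs, ys))  # residual sum of squares
    ss_tot = sum((x - mean_x) ** 2 for x in xs)  # total sum of squares
    r2 = 1 - ss_res / ss_tot
    return f"MAE = {mae:{fmt}}\nR2 = {r2:{fmt}}"


print(metrics_text([1, 2, 3, 4], [1, 2, 3, 5]))
# MAE = 0.25
# R2 = 0.8
```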

Args:

Returns:


function add_identity_line

add_identity_line(
    fig: 'Figure | Figure | Axes',
    line_kwds: 'dict[str, Any] | None' = None,
    trace_idx: 'int' = 0,
    **kwargs: 'Any'
) → Figure

Add a line shape to the background layer of a plotly figure spanning from smallest to largest x/y values in the trace specified by trace_idx.

Args:

Raises:

Returns:


function add_best_fit_line

add_best_fit_line(
    fig: 'Figure | Figure | Axes',
    xs: 'ArrayLike' = (),
    ys: 'ArrayLike' = (),
    trace_idx: 'int' = 0,
    line_kwds: 'dict[str, Any] | None' = None,
    annotate_params: 'bool | dict[str, Any]' = True,
    **kwargs: 'Any'
) → Figure

Add a least-squares line of best fit to a plotly or matplotlib figure.
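
The least-squares line has a closed form: slope = cov(x, y) / var(x), and the intercept makes the line pass through the point of means. A sketch (best_fit_params is hypothetical; statistics.linear_regression in Python 3.10+ returns the same slope and intercept):

```python
from collections.abc import Sequence

# Hypothetical helper, not pymatviz code: ordinary least squares for y = slope * x + intercept.


def best_fit_params(xs: Sequence[float], ys: Sequence[float]) -> tuple[float, float]:
    """Slope and intercept minimizing the sum of squared vertical distances."""
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    if var_x == 0:
        raise ValueError("xs must not all be equal")
    slope = cov_xy / var_x
    return slope, mean_y - slope * mean_x  # line passes through (mean_x, mean_y)


print(best_fit_params([0, 1, 2, 3], [1, 3, 5, 7]))  # (2.0, 1.0)
```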

Args:

Raises:

Returns:


function add_ecdf_line

add_ecdf_line(
    fig: 'Figure',
    values: 'ArrayLike' = (),
    trace_idx: 'int' = 0,
    trace_kwargs: 'dict[str, Any] | None' = None,
    **kwargs: 'Any'
) → Figure

Add an empirical cumulative distribution function (ECDF) line to a plotly figure.

Support for matplotlib is planned but not yet implemented. PRs welcome.

Args:

Returns:

module ptable

Various periodic table heatmaps with matplotlib and plotly.

Global Variables


function add_element_type_legend

add_element_type_legend(
    data: 'DataFrame | Series | dict[str, list[float]]',
    elem_class_colors: 'dict[str, str] | None' = None,
    legend_kwargs: 'dict[str, Any] | None' = None
) → None

Add a legend to a matplotlib figure showing the colors of element types.

Args:


function count_elements

count_elements(
    values: 'ElemValues',
    count_mode: 'CountMode' = 'composition',
    exclude_elements: 'Sequence[str]' = (),
    fill_value: 'float | None' = 0
) → Series

Count element occurrences in a list of formula strings or dict-like compositions. If the passed values are already a map from element symbol to counts, ensure the data is a pd.Series filled with zeros for missing element symbols.

Provided as a standalone function for external use or to cache long computations. To cache expensive element counts, refactor

ptable_heatmap(long_list_of_formulas)  # slow

to

elem_counts = count_elements(long_list_of_formulas)  # slow
ptable_heatmap(elem_counts)  # fast, only rerun this line to update the plot
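
A rough sketch of counting elements in simple formulas without parentheses. The two modes shown, 'composition' (sum the stoichiometric amounts) and 'occurrence' (count each element once per formula), are assumptions modeled on the count_mode parameter; count_elements_sketch is hypothetical and handles far fewer input types than the real function.

```python
import re
from collections import Counter
from collections.abc import Iterable

# Hypothetical sketch, not pymatviz code. Element symbol followed by an optional
# (possibly fractional) amount, e.g. "Fe2", "O", "Li0.5". No parentheses support.
ELEM_AMOUNT_RE = re.compile(r"([A-Z][a-z]?)(\d*\.?\d*)")


def count_elements_sketch(formulas: Iterable[str], count_mode: str = "composition") -> Counter:
    counts: Counter = Counter()
    for formula in formulas:
        amounts: Counter = Counter()
        for elem, amount in ELEM_AMOUNT_RE.findall(formula):
            amounts[elem] += float(amount) if amount else 1.0
        if count_mode == "composition":
            counts.update(amounts)  # add stoichiometric amounts
        elif count_mode == "occurrence":
            counts.update(amounts.keys())  # +1 per formula containing the element
        else:
            raise ValueError(f"Unsupported {count_mode=}")
    return counts


print(dict(count_elements_sketch(["Fe2O3", "FeO"])))  # {'Fe': 3.0, 'O': 4.0}
print(dict(count_elements_sketch(["Fe2O3", "FeO"], count_mode="occurrence")))  # {'Fe': 2, 'O': 2}
```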

Args:

Returns:


function data_preprocessor

data_preprocessor(data: 'SupportedDataType') → DataFrame

Preprocess input data for ptable plotters, including:

Returns:

Example:

data_dict: dict = {
    "H": 1.0,
    "He": [2.0, 4.0],
    "Li": [[6.0, 8.0], [10.0, 12.0]],
}

OR

data_df: pd.DataFrame = pd.DataFrame(
    data_dict.items(), columns=["Element", "Value"]
).set_index("Element")

OR

data_series: pd.Series = pd.Series(data_dict)

data_preprocessor(data_dict)  # or data_df, or data_series

   Element                       Value
0        H                     [1.0, ]
1       He                  [2.0, 4.0]
2       Li  [[6.0, 8.0], [10.0, 12.0]]

Metadata:
    vmin: 1.0
    vmax: 12.0


function handle_missing_and_anomaly

handle_missing_and_anomaly(df: 'DataFrame') → DataFrame

Handle missing values (NaN) and anomalies (infinity).

Infinity is replaced by vmax (for ∞) or vmin (for -∞). Missing values are handled by the selected strategy:

TODO: finish this function
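
The infinity handling described above can be sketched on a flat list of floats: +inf becomes the largest finite value (vmax) and -inf the smallest (vmin), while NaN is left for a separate missing-value strategy. replace_infinities is hypothetical and ignores the DataFrame layout the real function works on.

```python
import math
from collections.abc import Sequence

# Hypothetical sketch, not pymatviz code: only the infinity part, NaN passes through.


def replace_infinities(values: Sequence[float]) -> list[float]:
    """Clip +inf/-inf to the largest/smallest finite value; NaNs pass through."""
    finite = [val for val in values if math.isfinite(val)]
    if not finite:
        raise ValueError("need at least one finite value to derive vmin/vmax")
    vmin, vmax = min(finite), max(finite)
    out = []
    for val in values:
        if val == math.inf:
            out.append(vmax)
        elif val == -math.inf:
            out.append(vmin)
        else:
            out.append(val)  # finite values and NaN are kept unchanged
    return out


print(replace_infinities([1.0, math.inf, -2.5, -math.inf, math.nan]))
# [1.0, 1.0, -2.5, -2.5, nan]
```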


function ptable_heatmap

ptable_heatmap(
    values: 'ElemValues',
    log: 'bool | Normalize' = False,
    ax: 'Axes | None' = None,
    count_mode: 'CountMode' = 'composition',
    cbar_title: 'str' = 'Element Count',
    cbar_range: 'tuple[float | None, float | None] | None' = None,
    cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.05),
    cbar_kwargs: 'dict[str, Any] | None' = None,
    colorscale: 'str' = 'viridis',
    show_scale: 'bool' = True,
    show_values: 'bool' = True,
    infty_color: 'str' = 'lightskyblue',
    na_color: 'str' = 'white',
    heat_mode: "Literal['value', 'fraction', 'percent'] | None" = 'value',
    fmt: 'str | Callable[..., str] | None' = None,
    cbar_fmt: 'str | Callable[..., str] | None' = None,
    text_color: 'str | tuple[str, str]' = 'auto',
    exclude_elements: 'Sequence[str]' = (),
    zero_color: 'str' = '#eff',
    zero_symbol: 'str | float' = '-',
    text_style: 'dict[str, Any] | None' = None,
    label_font_size: 'int' = 16,
    value_font_size: 'int' = 12,
    tile_size: 'float | tuple[float, float]' = 0.9,
    rare_earth_voffset: 'float' = 0.5,
    **kwargs: 'Any'
) → Axes

Plot a heatmap across the periodic table of elements.

Args:

Returns:


function ptable_heatmap_splits

ptable_heatmap_splits(
    data: 'DataFrame | Series | dict[str, list[list[float]]]',
    colormap: 'str | None' = None,
    start_angle: 'float' = 135,
    symbol_text: 'str | Callable[[Element], str]' = <function <lambda> at 0x7f34a5276a20>,
    symbol_pos: 'tuple[float, float]' = (0.5, 0.5),
    cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
    cbar_title: 'str' = 'Values',
    on_empty: "Literal['hide', 'show']" = 'hide',
    ax_kwargs: 'dict[str, Any] | None' = None,
    symbol_kwargs: 'dict[str, Any] | None' = None,
    plot_kwargs: 'dict[str, Any] | Callable[[Sequence[float]], dict[str, Any]] | None' = None,
    cbar_title_kwargs: 'dict[str, Any] | None' = None,
    cbar_kwargs: 'dict[str, Any] | None' = None
) → Figure

Plot evenly-split heatmaps, nested inside a periodic table.

Args: data (pd.DataFrame | pd.Series | dict[str, list[list[float]]]): Map from element symbols to plot data. E.g. if dict,

Notes:

Default figsize is set to (0.75 * n_groups, 0.75 * n_periods).

Returns:


function ptable_heatmap_ratio

ptable_heatmap_ratio(
    values_num: 'ElemValues',
    values_denom: 'ElemValues',
    count_mode: 'CountMode' = 'composition',
    normalize: 'bool' = False,
    cbar_title: 'str' = 'Element Ratio',
    not_in_numerator: 'tuple[str, str] | None' = ('#eff', 'gray: not in 1st list'),
    not_in_denominator: 'tuple[str, str] | None' = ('lightskyblue', 'blue: not in 2nd list'),
    not_in_either: 'tuple[str, str] | None' = ('white', 'white: not in either'),
    **kwargs: 'Any'
) → Axes

Display the ratio of two maps from element symbols to heat values or of two sets of compositions.
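
How each element lands either in the ratio or in one of the three fallback categories (not_in_numerator, not_in_denominator, not_in_either) can be sketched on plain dicts of counts. element_ratios is a hypothetical helper; it ignores the normalize option and does no plotting.

```python
from collections.abc import Mapping

# Hypothetical sketch, not pymatviz code.


def element_ratios(
    num: Mapping[str, float], denom: Mapping[str, float]
) -> tuple[dict[str, float], dict[str, str]]:
    """Split elements into numeric ratios and elements missing from one or both maps."""
    ratios: dict[str, float] = {}
    missing: dict[str, str] = {}
    for elem in sorted(set(num) | set(denom)):
        n_val, d_val = num.get(elem, 0), denom.get(elem, 0)
        if n_val and d_val:
            ratios[elem] = n_val / d_val
        elif d_val:
            missing[elem] = "not in numerator"
        elif n_val:
            missing[elem] = "not in denominator"
        else:
            missing[elem] = "not in either"  # present as a key but zero in both
    return ratios, missing


ratios, missing = element_ratios(
    {"Fe": 4, "O": 6, "Li": 1, "H": 0}, {"Fe": 2, "O": 3, "Na": 5, "H": 0}
)
print(ratios)  # {'Fe': 2.0, 'O': 2.0}
print(missing)  # {'H': 'not in either', 'Li': 'not in denominator', 'Na': 'not in numerator'}
```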

Args:

Returns:


function ptable_heatmap_plotly

ptable_heatmap_plotly(
    values: 'ElemValues',
    count_mode: 'CountMode' = 'composition',
    colorscale: 'str | Sequence[str] | Sequence[tuple[float, str]]' = 'viridis',
    show_scale: 'bool' = True,
    show_values: 'bool' = True,
    heat_mode: "Literal['value', 'fraction', 'percent'] | None" = 'value',
    fmt: 'str | None' = None,
    hover_props: 'Sequence[str] | dict[str, str] | None' = None,
    hover_data: 'dict[str, str | int | float] | Series | None' = None,
    font_colors: 'Sequence[str]' = (),
    gap: 'float' = 5,
    font_size: 'int | None' = None,
    bg_color: 'str | None' = None,
    color_bar: 'dict[str, Any] | None' = None,
    cscale_range: 'tuple[float | None, float | None]' = (None, None),
    exclude_elements: 'Sequence[str]' = (),
    log: 'bool' = False,
    fill_value: 'float | None' = None,
    label_map: 'dict[str, str] | Callable[[str], str] | Literal[False] | None' = None,
    **kwargs: 'Any'
) → Figure

Create a Plotly figure with an interactive heatmap of the periodic table. Supports hover tooltips with custom data (see hover_data) or atomic reference data such as electronegativity and atomic_radius (see hover_props).

Args:

Returns:


function ptable_hists

ptable_hists(
    data: 'DataFrame | Series | dict[str, list[float]]',
    bins: 'int' = 20,
    colormap: 'str | None' = None,
    hist_kwds: 'dict[str, Any] | Callable[[Sequence[float]], dict[str, Any]] | None' = None,
    cbar_coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
    x_range: 'tuple[float | None, float | None] | None' = None,
    symbol_kwargs: 'Any' = None,
    symbol_text: 'str | Callable[[Element], str]' = <function <lambda> at 0x7f34a5276ca0>,
    cbar_title: 'str' = 'Values',
    cbar_title_kwds: 'dict[str, Any] | None' = None,
    cbar_kwds: 'dict[str, Any] | None' = None,
    symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
    log: 'bool' = False,
    anno_kwds: 'dict[str, Any] | None' = None,
    on_empty: "Literal['show', 'hide']" = 'hide',
    color_elem_types: "Literal['symbol', 'background', 'both', False] | dict[str, str]" = 'background',
    elem_type_legend: 'bool | dict[str, Any]' = True,
    **kwargs: 'Any'
) → Figure

Plot small histograms for each element laid out in a periodic table.

Args:

Returns:


function ptable_scatters

ptable_scatters(
    data: 'DataFrame | Series | dict[str, list[list[float]]]',
    symbol_text: 'str | Callable[[Element], str]' = <function <lambda> at 0x7f34a5276de0>,
    symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
    on_empty: "Literal['hide', 'show']" = 'hide',
    plot_kwargs: 'dict[str, Any] | Callable[[Sequence[float]], dict[str, Any]] | None' = None,
    child_args: 'dict[str, Any] | None' = None,
    ax_kwargs: 'dict[str, Any] | None' = None,
    symbol_kwargs: 'dict[str, Any] | None' = None
) → Figure

Make scatter plots for each element, nested inside a periodic table.

Args: data (pd.DataFrame | pd.Series | dict[str, list[list[float]]]): Map from element symbols to plot data. E.g. if dict,

TODO: allow colormap with 3rd data dimension


function ptable_lines

ptable_lines(
    data: 'DataFrame | Series | dict[str, list[list[float]]]',
    symbol_text: 'str | Callable[[Element], str]' = <function <lambda> at 0x7f34a5276f20>,
    symbol_pos: 'tuple[float, float]' = (0.5, 0.8),
    on_empty: "Literal['hide', 'show']" = 'hide',
    plot_kwargs: 'dict[str, Any] | Callable[[Sequence[float]], dict[str, Any]] | None' = None,
    child_args: 'dict[str, Any] | None' = None,
    ax_kwargs: 'dict[str, Any] | None' = None,
    symbol_kwargs: 'dict[str, Any] | None' = None
) → Figure

Line plots for each element, nested inside a periodic table.

Args: data (pd.DataFrame | pd.Series | dict[str, list[list[float]]]): Map from element symbols to plot data. E.g. if dict,


class PTableProjector

Project (nest) a custom plot into a periodic table.

Scopes mentioned in this plotter:
    plot: refers to the global scope.
    ax: refers to the axes where the child plotter plots.
    child: refers to the child plotter itself, for example ax.plot().

method __init__

__init__(
    data: 'SupportedDataType',
    colormap: 'str | Colormap | None',
    plot_kwargs: 'dict[str, Any] | None' = None
) → None

Initialize a ptable projector.

Default figsize is set to (0.75 * n_groups, 0.75 * n_periods), and axes are turned off by default.

Args:


property cmap

The global Colormap.

Returns:


property data

The preprocessed data.

Returns:


property norm

Data min-max normalizer.


method add_child_plots

add_child_plots(
    child_plotter: 'Callable[[axes, Any], None]',
    child_args: 'dict[str, Any]',
    ax_kwargs: 'dict[str, Any]',
    on_empty: "Literal['hide', 'show']" = 'hide'
) → None

Add custom child plots to the periodic table grid.

Args:


method add_colorbar

add_colorbar(
    title: 'str',
    coords: 'tuple[float, float, float, float]' = (0.18, 0.8, 0.42, 0.02),
    cbar_kwargs: 'dict[str, Any] | None' = None,
    title_kwargs: 'dict[str, Any] | None' = None
) → None

Add a global colorbar.

Args:


method add_ele_symbols

add_ele_symbols(
    text: 'str | Callable[[Element], str]' = <function PTableProjector.<lambda> at 0x7f34a5276660>,
    pos: 'tuple[float, float]' = (0.5, 0.5),
    kwargs: 'dict[str, Any] | None' = None
) → None

Add element symbols for each tile.

Args:


class ChildPlotters

Collection of predefined child plotters.


method line

line(ax: 'axes', data: 'SupportedValueType', **child_args: 'Any') → None

Line plotter.

Args:


method rectangle

rectangle(
    ax: 'axes',
    data: 'SupportedValueType',
    norm: 'Normalize',
    cmap: 'Colormap',
    start_angle: 'float'
) → None

Rectangle heatmap plotter.

Each rectangle can be split evenly depending on the length of its data, and tiles with different numbers of splits can be mixed and matched.

Args:


method scatter

scatter(ax: 'axes', data: 'SupportedValueType', **child_args: 'Any') → None

Scatter plotter.

Args:

module relevance

Plots for evaluating classifier performance.

Global Variables


function roc_curve

roc_curve(
    targets: 'ArrayLike | str',
    proba_pos: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None
) → tuple[float, Axes]

Plot the receiver operating characteristic curve of a binary classifier given target labels and predicted probabilities for the positive class.
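
The curve traces the true-positive rate against the false-positive rate as the decision threshold sweeps down through the predicted probabilities; the area under it (ROC AUC) follows from the trapezoidal rule. A pure-Python sketch with hypothetical helper names, treating tied scores as a single threshold:

```python
from collections.abc import Sequence

# Hypothetical helpers, not pymatviz code.


def roc_points(
    targets: Sequence[int], proba_pos: Sequence[float]
) -> tuple[list[float], list[float]]:
    """False- and true-positive rates for every distinct threshold, highest first."""
    pairs = sorted(zip(proba_pos, targets), key=lambda pair: pair[0], reverse=True)
    n_pos = sum(1 for target in targets if target)
    n_neg = len(targets) - n_pos
    fpr, tpr = [0.0], [0.0]
    true_pos = false_pos = 0
    idx = 0
    while idx < len(pairs):
        threshold = pairs[idx][0]
        # consume all samples tied at this threshold before emitting a point
        while idx < len(pairs) and pairs[idx][0] == threshold:
            if pairs[idx][1]:
                true_pos += 1
            else:
                false_pos += 1
            idx += 1
        fpr.append(false_pos / n_neg)
        tpr.append(true_pos / n_pos)
    return fpr, tpr


def trapezoid_area(xs: Sequence[float], ys: Sequence[float]) -> float:
    return sum((x1 - x0) * (y0 + y1) / 2 for x0, x1, y0, y1 in zip(xs, xs[1:], ys, ys[1:]))


fpr, tpr = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(fpr, tpr)  # [0.0, 0.0, 0.5, 0.5, 1.0] [0.0, 0.5, 0.5, 1.0, 1.0]
print(trapezoid_area(fpr, tpr))  # 0.75
```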

Args:

Returns:


function precision_recall_curve

precision_recall_curve(
    targets: 'ArrayLike | str',
    proba_pos: 'ArrayLike | str',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None
) → tuple[float, Axes]

Plot the precision-recall curve of a binary classifier.

Args:

Returns:

module sankey

Sankey diagram for comparing distributions in two dataframe columns.

Global Variables


function sankey_from_2_df_cols

sankey_from_2_df_cols(
    df: 'DataFrame',
    cols: 'list[str]',
    labels_with_counts: "bool | Literal['percent']" = True,
    **kwargs: 'Any'
) → Figure

Plot two columns of a dataframe as a Plotly Sankey diagram.
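
A Sankey diagram needs node labels plus a source index, target index and value for every link. A sketch of deriving those from two columns by counting value pairs; sankey_links is hypothetical, and right-hand nodes are offset so a label appearing in both columns becomes two distinct nodes. The four lists match the node=dict(label=...) and link=dict(source=..., target=..., value=...) arguments of plotly.graph_objects.Sankey.

```python
from collections import Counter
from collections.abc import Sequence

# Hypothetical sketch, not pymatviz code.


def sankey_links(
    left_col: Sequence[str], right_col: Sequence[str]
) -> tuple[list[str], list[int], list[int], list[int]]:
    """Node labels plus source/target indices and counts for each observed value pair."""
    flows = Counter(zip(left_col, right_col))
    left_labels = sorted(set(left_col))
    right_labels = sorted(set(right_col))
    left_idx = {label: idx for idx, label in enumerate(left_labels)}
    # offset right-hand nodes so a label present in both columns yields two distinct nodes
    right_idx = {label: len(left_labels) + idx for idx, label in enumerate(right_labels)}
    sources, targets, values = [], [], []
    for (left, right), count in sorted(flows.items()):
        sources.append(left_idx[left])
        targets.append(right_idx[right])
        values.append(count)
    return left_labels + right_labels, sources, targets, values


labels, sources, targets, values = sankey_links(["x", "x", "y"], ["p", "q", "q"])
print(labels, sources, targets, values)  # ['x', 'y', 'p', 'q'] [0, 0, 1] [2, 3, 3] [1, 1, 1]
```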

Args:

Raises:

Returns:

module structure_viz

2D plots of pymatgen structures with matplotlib.

Global Variables


function unit_cell_to_lines

unit_cell_to_lines(cell: 'ArrayLike') → tuple[ArrayLike, ArrayLike, ArrayLike]

Convert lattice vectors to plot lines.
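
A unit cell spanned by lattice vectors a, b and c has 8 corners (sums of subsets of the vectors) and 12 edges, each joining two corners that differ by exactly one lattice vector. A sketch of enumerating those edges (unit_cell_edges is hypothetical and returns plain tuples rather than the arrays unit_cell_to_lines returns):

```python
import math
from itertools import product

# Hypothetical sketch, not pymatviz code.
Vec3 = tuple[float, float, float]


def unit_cell_edges(cell: list[list[float]]) -> list[tuple[Vec3, Vec3]]:
    """Return the 12 edges of the parallelepiped spanned by the 3 lattice vectors (rows)."""
    vec_a, vec_b, vec_c = cell

    def corner(i: int, j: int, k: int) -> Vec3:
        # corner = i*a + j*b + k*c with i, j, k in {0, 1}
        return tuple(i * vec_a[d] + j * vec_b[d] + k * vec_c[d] for d in range(3))

    edges = []
    for i, j, k in product((0, 1), repeat=3):
        # step along a, b or c, staying inside the cell
        for di, dj, dk in ((1, 0, 0), (0, 1, 0), (0, 0, 1)):
            if i + di <= 1 and j + dj <= 1 and k + dk <= 1:
                edges.append((corner(i, j, k), corner(i + di, j + dj, k + dk)))
    return edges


edges = unit_cell_edges([[2, 0, 0], [0, 3, 0], [0, 0, 4]])
print(len(edges), sorted(math.dist(start, end) for start, end in edges))
# 12 [2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0]
```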

Args:

Returns: tuple[np.array, np.array, np.array]:


function plot_structure_2d

plot_structure_2d(
    struct: 'Structure',
    ax: 'Axes | None' = None,
    rotation: 'str' = '10x,10y,0z',
    atomic_radii: 'float | dict[str, float] | None' = None,
    colors: 'dict[str, str | list[float]] | None' = None,
    scale: 'float' = 1,
    show_unit_cell: 'bool' = True,
    show_bonds: 'bool | NearNeighbors' = False,
    site_labels: "bool | Literal['symbol', 'species'] | dict[str, str | float] | Sequence[str | float]" = True,
    site_labels_bbox: 'dict[str, Any] | None' = None,
    label_kwargs: 'dict[str, Any] | None' = None,
    bond_kwargs: 'dict[str, Any] | None' = None,
    standardize_struct: 'bool | None' = None,
    axis: 'bool | str' = 'off'
) → Axes

Plot pymatgen structures in 2d with matplotlib.

Inspired by ASE’s ase.visualize.plot.plot_atoms() (https://wiki.fysik.dtu.dk/ase/ase/visualize/visualize.html#matplotlib). pymatviz aims to give similar output to ASE but supports disordered structures and avoids the conversion hassle of AseAtomsAdaptor.get_atoms(pmg_struct).

For example, these two snippets should give very similar output:

from pymatgen.ext.matproj import MPRester

mp_19017 = MPRester().get_structure_by_material_id("mp-19017")

# ASE
from ase.visualize.plot import plot_atoms
from pymatgen.io.ase import AseAtomsAdaptor

plot_atoms(AseAtomsAdaptor().get_atoms(mp_19017), rotation="10x,10y,0z", radii=0.5)

# pymatviz
from pymatviz import plot_structure_2d

plot_structure_2d(mp_19017)

Multiple structures in single figure example:

import matplotlib.pyplot as plt
from pymatgen.ext.matproj import MPRester
from pymatviz import plot_structure_2d

structures = [
    MPRester().get_structure_by_material_id(f"mp-{idx}") for idx in range(1, 5)
]
fig, axs = plt.subplots(2, 2, figsize=(12, 12))

for struct, ax in zip(structures, axs.flat):
    plot_structure_2d(struct, ax=ax)

Args:

Raises:

Returns:



class ExperimentalWarning

Warning category used for the experimental show_bonds feature.


module sunburst

Hierarchical multi-level pie charts (i.e. sunbursts), e.g. for crystal symmetry distributions.

Global Variables



function spacegroup_sunburst

spacegroup_sunburst(
    data: 'Sequence[int | str] | Series',
    show_counts: "Literal['value', 'percent', False]" = False,
    **kwargs: 'Any'
) → Figure

Generate a sunburst plot with crystal systems as the inner ring for a list of international space group numbers.

Hint: To hide very small labels, set a uniformtext minsize and mode='hide':

fig.update_layout(uniformtext=dict(minsize=9, mode="hide"))
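The sketch below shows the two-ring data such a plot needs. It uses the standard space group ranges of the seven crystal systems: 1-2 triclinic, 3-15 monoclinic, 16-74 orthorhombic, 75-142 tetragonal, 143-167 trigonal, 168-194 hexagonal and 195-230 cubic. The helper `sunburst_data` is hypothetical and only accepts integer space group numbers, not symbols. Its outputs fit plotly's `go.Sunburst(ids=..., labels=..., parents=..., values=..., branchvalues="total")`, since each inner wedge's value equals the sum of its children.

```python
from bisect import bisect_left
from collections import Counter

# highest international space group number of each crystal system, in order
SPG_UPPER_BOUNDS = (2, 15, 74, 142, 167, 194, 230)
CRYSTAL_SYSTEMS = (
    "triclinic", "monoclinic", "orthorhombic", "tetragonal", "trigonal", "hexagonal", "cubic",
)


def sunburst_data(spg_nums):
    """Build ids, labels, parents and values for a two-ring sunburst."""
    spg_counts = Counter(spg_nums)
    system_counts = Counter()
    ids, labels, parents, values = [], [], [], []
    for spg, count in sorted(spg_counts.items()):
        if not 1 <= spg <= 230:
            raise ValueError(f"Invalid space group number {spg}")
        # first crystal system whose upper bound is >= spg
        system = CRYSTAL_SYSTEMS[bisect_left(SPG_UPPER_BOUNDS, spg)]
        system_counts[system] += count
        ids.append(f"{system}/{spg}")  # unique id per outer-ring wedge
        labels.append(str(spg))
        parents.append(system)
        values.append(count)
    # inner ring: one wedge per crystal system, valued as the sum of its children
    for system, count in system_counts.items():
        ids.append(system)
        labels.append(system)
        parents.append("")
        values.append(count)
    return ids, labels, parents, values
```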

Args:

Returns:


module uncertainty

Visualizations for assessing the quality of model uncertainty estimates.

Global Variables



function qq_gaussian

qq_gaussian(
    y_true: 'ArrayLike | str',
    y_pred: 'ArrayLike | str',
    y_std: 'ArrayLike | dict[str, ArrayLike] | str | Sequence[str]',
    df: 'DataFrame | None' = None,
    ax: 'Axes | None' = None,
    identity_line: 'bool | dict[str, Any]' = True
) → Axes

Create a Gaussian quantile-quantile (Q-Q) plot for one (passed as array) or multiple (passed as dict) sets of uncertainty estimates for a single pair of ground-truth targets y_true and model predictions y_pred.

Overconfidence relative to a Gaussian distribution is visualized as shaded areas below the parity line, and underconfidence (oversized uncertainties) as shaded areas above it.

The measure of calibration is how well the uncertainty percentiles conform to those of a normal distribution.

Inspired by https://git.io/JufOz. Info on Q-Q plots: https://wikipedia.org/wiki/Q-Q_plot
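The sketch below shows the assumed mechanics behind such a calibration curve. Each residual is z-scored by its predicted standard deviation. Then, for each expected proportion p, it counts the fraction of z-scores that fall inside the central standard-normal interval holding p of the probability mass. The helper `qq_proportions` is hypothetical and only handles a single set of uncertainties. Observed values below expected mean overconfidence, and values above mean underconfidence, as described above.

```python
from statistics import NormalDist


def qq_proportions(y_true, y_pred, y_std, n_points=11):
    """Return expected and observed proportions of z-scored residuals."""
    # z-score each absolute residual by its predicted standard deviation
    z_scores = [abs(true - pred) / std for true, pred, std in zip(y_true, y_pred, y_std)]
    std_normal = NormalDist()
    expected, observed = [], []
    for idx in range(n_points):
        prop = idx / (n_points - 1)
        # half-width of the central interval holding a fraction `prop` of N(0, 1);
        # inv_cdf is undefined at 1, where the interval becomes the whole real line
        z_cut = std_normal.inv_cdf((1 + prop) / 2) if prop < 1 else float("inf")
        expected.append(prop)
        observed.append(sum(z <= z_cut for z in z_scores) / len(z_scores))
    return expected, observed
```

Plotting observed against expected together with the line y = x gives the Q-Q curve.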

Args:

Returns:



function get_err_decay

get_err_decay(
    y_true: 'ArrayLike',
    y_pred: 'ArrayLike',
    n_rand: 'int' = 100
) → tuple[ArrayLike, ArrayLike]

Calculate the model’s error curve as samples are excluded from the calculation based on their absolute error.

Use in combination with get_std_decay to see what the error drop curve would look like if model error and uncertainty were perfectly rank-correlated.
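The NumPy sketch below shows one way to compute such a curve. It assumes the curve is the mean absolute error (MAE) of the samples that remain after excluding the largest errors first. It also returns a random-exclusion baseline averaged over n_rand shuffles. `err_decay_sketch` and its `seed` parameter are hypothetical. Entry k-1 of each returned array is the MAE of the k samples kept.

```python
import numpy as np


def err_decay_sketch(y_true, y_pred, n_rand=100, seed=0):
    """Return MAE after excluding largest errors first, plus a random baseline."""
    abs_err = np.abs(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))
    n_kept = np.arange(1, len(abs_err) + 1)
    # sorting ascending and taking cumulative means = keeping the k smallest errors
    decay_by_err = np.sort(abs_err).cumsum() / n_kept
    rng = np.random.default_rng(seed)
    # baseline: same cumulative mean, but samples kept in random order
    rand_curves = np.array([rng.permutation(abs_err).cumsum() / n_kept for _ in range(n_rand)])
    return decay_by_err, rand_curves.mean(axis=0)
```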

Args:

Returns:



function get_std_decay

get_std_decay(
    y_true: 'ArrayLike',
    y_pred: 'ArrayLike',
    y_std: 'ArrayLike'
) → ArrayLike

Calculate the drop in model error as samples are excluded from the calculation based on the model’s uncertainty.

For models able to estimate their own uncertainty well, meaning predictions of larger error are associated with larger uncertainty, the error curve should fall off sharply at first as the highest-error points are discarded, and slowly towards the end where only small-error samples with little uncertainty remain.

Note that even perfect model uncertainties would not make this error drop curve coincide exactly with the one returned by get_err_decay. In some cases, the model may have made an accurate prediction purely by chance: the error is small, yet a good uncertainty estimate would still be large. Such a sample is then excluded at different x-axis locations by the two methods, so the get_std_decay curve lies higher.
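The NumPy sketch below illustrates the idea. It assumes samples are excluded in order of decreasing y_std, and each entry is the MAE of the samples that remain. `std_decay_sketch` is a hypothetical helper. When uncertainties are perfectly rank-correlated with the absolute errors, its output equals the error-sorted curve.

```python
import numpy as np


def std_decay_sketch(y_true, y_pred, y_std):
    """Return MAE after excluding the most uncertain samples first."""
    abs_err = np.abs(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))
    # order samples by increasing uncertainty so the most uncertain are dropped first
    order = np.argsort(np.asarray(y_std, dtype=float))
    n_kept = np.arange(1, len(abs_err) + 1)
    # entry k-1: MAE of the k least uncertain samples
    return abs_err[order].cumsum() / n_kept
```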

Args:

Returns:



function error_decay_with_uncert

error_decay_with_uncert(
    y_true: 'ArrayLike | str',
    y_pred: 'ArrayLike | str',
    y_std: 'ArrayLike | dict[str, ArrayLike] | str | Sequence[str]',
    df: 'DataFrame | None' = None,
    n_rand: 'int' = 100,
    percentiles: 'bool' = True,
    ax: 'Axes | None' = None
) → Axes

Plot for assessing the quality of uncertainty estimates. If a model’s uncertainty is well calibrated, i.e. strongly correlated with its error, removing the most uncertain predictions should make the mean error decay similarly to how it decays when removing the predictions of largest error.

Args:

Note: If you’re not happy with the default y_max of 1.1 * rand_mean, where rand_mean is the mean of the random sample exclusion curve, use ax.set(ylim=[None, some_value * ax.get_ylim()[1]]).

Returns:


module utils

pymatviz utility functions.

Global Variables



function pretty_label

pretty_label(key: 'str', backend: 'Backend') → str

Map metric keys to their pretty labels.



function crystal_sys_from_spg_num

crystal_sys_from_spg_num(spg: 'float') → CrystalSystem

Get the crystal system for an international space group number.
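The sketch below maps a space group number to its crystal system using the standard ranges: 1-2 triclinic, 3-15 monoclinic, 16-74 orthorhombic, 75-142 tetragonal, 143-167 trigonal, 168-194 hexagonal and 195-230 cubic. `crystal_system_sketch` is hypothetical and returns a plain string. Whether pymatviz validates its input the same way is an assumption.

```python
# (highest space group number, crystal system) for every system except cubic
_SYSTEM_BOUNDS = (
    (2, "triclinic"), (15, "monoclinic"), (74, "orthorhombic"),
    (142, "tetragonal"), (167, "trigonal"), (194, "hexagonal"),
)


def crystal_system_sketch(spg: float) -> str:
    """Map an international space group number (1-230) to its crystal system."""
    # reject non-integers (including NaN and inf) and out-of-range numbers
    if not float(spg).is_integer() or not 1 <= spg <= 230:
        raise ValueError(f"Invalid space group number {spg}, expected an integer in 1-230")
    for upper, system in _SYSTEM_BOUNDS:
        if spg <= upper:
            return system
    return "cubic"  # 195-230
```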



function df_to_arrays

df_to_arrays(
    df: 'DataFrame | None',
    *args: 'str | Sequence[str] | Sequence[ArrayLike]',
    strict: 'bool' = True
) → list[ArrayLike | dict[str, ArrayLike]]

If df is None, this is a no-op: args are returned as-is. If df is a dataframe, all following args are used as column names and the column data returned as arrays (after dropping rows with NaNs in any column).
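The pandas sketch below illustrates this behavior. It assumes that a single column name yields an array and a sequence of names yields a dict of arrays, consistent with the return type above. `df_to_arrays_sketch` is hypothetical and omits the `strict` flag.

```python
from __future__ import annotations

from collections.abc import Sequence

import pandas as pd


def df_to_arrays_sketch(df: pd.DataFrame | None, *args: str | Sequence[str]) -> list:
    """Return column data as arrays (or dicts of arrays), or args unchanged if df is None."""
    if df is None:
        return list(args)  # no-op: args are returned as-is
    # flatten single names and sequences of names, dropping duplicates but keeping order
    cols = [col for arg in args for col in ([arg] if isinstance(arg, str) else arg)]
    cols = list(dict.fromkeys(cols))
    # drop rows with NaN in any referenced column so all outputs stay aligned
    df_clean = df[cols].dropna()
    out = []
    for arg in args:
        if isinstance(arg, str):
            out.append(df_clean[arg].to_numpy())
        else:
            out.append({col: df_clean[col].to_numpy() for col in arg})
    return out
```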

Args:

Raises:

Returns:



function bin_df_cols

bin_df_cols(
    df_in: 'DataFrame',
    bin_by_cols: 'Sequence[str]',
    group_by_cols: 'Sequence[str]' = (),
    n_bins: 'int | Sequence[int]' = 100,
    bin_counts_col: 'str' = 'bin_counts',
    kde_col: 'str' = '',
    verbose: 'bool' = True
) → DataFrame

Bin columns of a DataFrame.
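The pandas sketch below shows one plausible reading, which could be used, for example, to thin out dense scatter plots. It assumes each column in bin_by_cols is cut into n_bins equal-width intervals and rows are grouped by the resulting bins. One representative row is kept per occupied bin, along with its row count in bin_counts_col. The helper name is hypothetical, and group_by_cols, per-column n_bins, kde_col and verbose are omitted.

```python
from __future__ import annotations

from collections.abc import Sequence

import pandas as pd


def bin_df_cols_sketch(
    df_in: pd.DataFrame, bin_by_cols: Sequence[str], n_bins: int = 100,
    bin_counts_col: str = "bin_counts",
) -> pd.DataFrame:
    """Keep one row per occupied bin and count the rows that fell into it."""
    bin_labels = [f"{col}_bin" for col in bin_by_cols]
    df_bin = df_in.copy()
    for col, label in zip(bin_by_cols, bin_labels):
        # assign each row to one of n_bins equal-width intervals of this column
        df_bin[label] = pd.cut(df_bin[col], bins=n_bins)
    grouped = df_bin.groupby(bin_labels, observed=True)
    df_out = grouped.first()  # first row of each occupied bin
    df_out[bin_counts_col] = grouped.size()  # aligned on the same bin index
    return df_out.reset_index(drop=True)
```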

Args:

Returns:



function patch_dict

patch_dict(
    dct: 'dict[Any, Any]',
    *args: 'Any',
    **kwargs: 'Any'
) → Generator[dict[Any, Any], None, None]

Context manager to temporarily patch the specified keys in a dictionary and restore it to its original state on context exit.

Useful e.g. for temporary plotly fig.layout mutations:

with patch_dict(fig.layout, showlegend=False):
    fig.write_image("plot.pdf")
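The sketch below shows a minimal version of such a context manager for plain dicts, built on contextlib.contextmanager. `patch_dict_sketch` is hypothetical, and plotly layout objects may need different handling. Keys that did not exist before patching are removed again on exit. The original values are restored even if the body raises an exception.

```python
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any

_MISSING = object()  # sentinel for keys absent before patching


@contextmanager
def patch_dict_sketch(dct: dict, *args: Any, **kwargs: Any) -> Iterator[dict[Any, Any]]:
    """Temporarily update dct, restoring its original state on exit."""
    patch = dict(*args, **kwargs)
    saved = {key: dct.get(key, _MISSING) for key in patch}
    dct.update(patch)
    try:
        yield dct
    finally:
        for key, old in saved.items():
            if old is _MISSING:
                dct.pop(key, None)  # key was added by the patch
            else:
                dct[key] = old
```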

Args:

Yields:



function luminance

luminance(color: 'tuple[float, float, float]') → float

Compute the luminance of a color as in https://stackoverflow.com/a/596243.
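The linked answer lists several formulas. The sketch below assumes the perceived-brightness weights 0.299, 0.587 and 0.114 (ITU-R BT.601) and RGB channels scaled to [0, 1]. Whether pymatviz uses exactly these weights is an assumption. `luminance_sketch` is hypothetical.

```python
def luminance_sketch(color: tuple[float, float, float]) -> float:
    """Perceived brightness of an RGB color with channels in [0, 1]."""
    red, green, blue = color
    # weighted sum: the eye is most sensitive to green and least to blue
    return 0.299 * red + 0.587 * green + 0.114 * blue
```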

Args:

Returns:



function pick_bw_for_contrast

pick_bw_for_contrast(
    color: 'tuple[float, float, float]',
    text_color_threshold: 'float' = 0.7
) → str

Choose black or white text color for a given background color based on luminance.
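The sketch below shows a simple threshold rule: dark text on light backgrounds and light text on dark ones. It assumes RGB channels in [0, 1] and BT.601 perceived-brightness weights. It also assumes a strict "greater than" comparison against text_color_threshold. `pick_bw_sketch` is hypothetical.

```python
def pick_bw_sketch(color: tuple[float, float, float], text_color_threshold: float = 0.7) -> str:
    """Return "black" for light backgrounds and "white" for dark ones."""
    red, green, blue = color  # channels assumed in [0, 1]
    # perceived brightness with ITU-R BT.601 weights (assumed formula)
    lum = 0.299 * red + 0.587 * green + 0.114 * blue
    return "black" if lum > text_color_threshold else "white"
```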

Args:

Returns:



function si_fmt

si_fmt(
    val: 'float',
    fmt: 'str' = '.1f',
    sep: 'str' = '',
    binary: 'bool' = False,
    decimal_threshold: 'float' = 0.01
) → str

Convert large numbers into human-readable format using SI prefixes.

Supports binary (1024) and metric (1000) modes.

https://nist.gov/pml/weights-and-measures/metric-si-prefixes
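The sketch below handles only magnitudes of at least 1. The small-value prefixes controlled by decimal_threshold are omitted. In binary mode it divides by 1024 but keeps the plain prefixes rather than IEC ones like "Ki". Whether pymatviz appends such suffixes is not stated here. `si_fmt_sketch` is hypothetical.

```python
def si_fmt_sketch(val: float, fmt: str = ".1f", sep: str = "", binary: bool = False) -> str:
    """Format a number with an SI prefix (large magnitudes only)."""
    base = 1024 if binary else 1000
    prefixes = ["", "k", "M", "G", "T", "P", "E"]
    magnitude, idx = abs(val), 0
    # divide by the base until the number drops below it or prefixes run out
    while magnitude >= base and idx < len(prefixes) - 1:
        magnitude /= base
        idx += 1
    sign = "-" if val < 0 else ""
    return f"{sign}{magnitude:{fmt}}{sep}{prefixes[idx]}"
```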

Args:

Returns:



function styled_html_tag

styled_html_tag(text: 'str', tag: 'str' = 'span', style: 'str' = '') → str

Wrap text in an HTML tag (a span by default) with a custom style.

The style defaults to decreased font size and weight, e.g. to display units in plotly labels and annotations.
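The sketch below shows the wrapping. The default style string is a hypothetical stand-in for "decreased font size and weight", since the actual default is not given here. `styled_html_tag_sketch` is hypothetical. Plotly renders a limited subset of HTML, including styled span tags, in titles and annotations.

```python
def styled_html_tag_sketch(text: str, tag: str = "span", style: str = "") -> str:
    """Wrap text in an HTML tag carrying an inline style."""
    # hypothetical default: smaller, lighter text, e.g. for units in axis labels
    style = style or "font-size: 0.8em; font-weight: lighter;"
    return f"<{tag} style='{style}'>{text}</{tag}>"
```

For example, f"Energy {styled_html_tag_sketch('(eV)')}" could serve as a plotly axis title with a de-emphasized unit.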

Args:



function validate_fig

validate_fig(func: 'Callable[P, R]') → Callable[P, R]

Decorator to validate the type of the fig keyword argument in a function.



function annotate

annotate(
    text: 'str',
    fig: 'AxOrFig | None' = None,
    color: 'str' = 'black',
    **kwargs: 'Any'
) → AxOrFig

Annotate a matplotlib or plotly figure.

Args:

Returns:



function get_fig_xy_range

get_fig_xy_range(
    fig: 'Figure | Figure | Axes',
    trace_idx: 'int' = 0
) → tuple[tuple[float, float], tuple[float, float]]

Get the x and y range of a plotly or matplotlib figure.

Args:

Returns: