Source code for darshan.experimental.plots.plot_dxt_heatmap

"""
Module for creating the ranks vs. time IO intensity
heatmap figure for the Darshan job summary.
"""

from __future__ import annotations
import functools
from typing import (Any, List, Sequence, Union,
                    TYPE_CHECKING, Tuple, Optional)

import numpy as np

if TYPE_CHECKING:
    import numpy.typing as npt

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

import darshan
from darshan.experimental.plots import heatmap_handling


[docs]@functools.lru_cache(maxsize=10) def determine_hmap_runtime(report: darshan.DarshanReport) -> Tuple[float, float]: """ Determine the effective heatmap runtime to be used for plotting in cases where only DXT, only HEATMAP, or both module types are available, to achieve a common max displayed runtime. Parameters ---------- report: a ``darshan.DarshanReport`` Returns ------- A tuple containing `tmax`, `runtime` floats. """ # get the elapsed runtime runtime = report.metadata["job"]["run_time"] # leverage higher resolution DXT timing # data if available if ("DXT_POSIX" in report.records or "DXT_MPIIO" in report.records): tmax = 0.0 for mod in report.modules: if "DXT" in mod: agg_df = heatmap_handling.get_aggregate_data(report=report, mod=mod, ops=["read", "write"]) tmax_dxt = float(agg_df["end_time"].max()) if tmax_dxt > tmax: tmax = tmax_dxt else: tmax = runtime return tmax, runtime
[docs]def get_x_axis_ticks(bin_max: float, n_xlabels: int = 4) -> npt.NDArray[np.float64]: """ Creates the x-axis tick mark locations. Parameters ---------- bin_max: the maximum number of bins. n_xlabels: the number of x-axis tick marks to create. Default is 4. Returns ------- Array of x-axis tick mark locations of length ``n_xlabels``. """ return np.linspace(0, bin_max, n_xlabels)
[docs]def get_x_axis_tick_labels( max_time: float, n_xlabels: int = 4 ) -> Union[npt.NDArray[np.float64], npt.NDArray[np.intc]]: """ Creates the x-axis tick mark labels. Parameters ---------- max_time: the maximum time to plot. n_xlabels: the number of x-axis tick marks to create. Default is 4. Returns ------- x_ticklabels: array of x-axis tick mark labels of length ``n_xlabels``. """ # for the x tick labels, start at 0 and end with # the max time (converted to an integer) if max_time <= 1: # use 2 decimal places for run times less than 1 second x_ticklabels = np.around(np.linspace(0.0, max_time, n_xlabels), decimals=2) elif (max_time > 1) & (max_time <= 10): # use 1 decimal place for run times between 1 and 10 seconds x_ticklabels = np.around(np.linspace(0.0, max_time, n_xlabels), decimals=1) else: # for run times greater than 10 seconds, round the max # time up and round labels to the nearest integer x_ticklabels = np.linspace(0.0, np.ceil(max_time), n_xlabels, dtype=int) return x_ticklabels
[docs]def get_y_axis_ticks(ax: Any, n_ylabels: int = 6) -> npt.NDArray[np.float64]: """ Creates the y-axis tick mark locations. Parameters ---------- ax: a ``matplotlib`` axis object. n_ylabels: The number of y-axis tick mark labels to create. Default is 6. Returns ------- yticks: array of y-axis tick mark locations of length ``n_ylabels``. """ # get the original y-axis tick locations initial_yticks = ax.get_yticks() if len(initial_yticks) < n_ylabels: # if there are less tick marks available than requested, # use the original tick mark locations yticks = initial_yticks else: # use the original tick marks to make a new array that # contains a subset of the original tick marks tick_idx = np.round(np.linspace(0, initial_yticks.size - 1, n_ylabels)).astype( int ) yticks = initial_yticks[tick_idx] return yticks
[docs]def get_yticklabels(ax: Any) -> List[str]: """ Utility function for ``get_y_axis_tick_labels`` that retrieves the y-axis tick mark labels from the input axis. Parameters ---------- ax: a ``matplotlib`` axis object. Returns ------- y_ticklabels: list of y-axis tick mark labels of length ``n_ylabels``. """ # retrieve the original y-axis tick label strings from the axis object y_ticklabels = [tl.get_text() for tl in ax.get_yticklabels()] return y_ticklabels
[docs]def get_y_axis_tick_labels(ax: Any, n_ylabels: int = 6) -> npt.NDArray[np.intc]: """ Sets the y-axis tick mark labels. Parameters ---------- ax: a ``matplotlib`` axis object. n_ylabels: The number of y-axis tick mark labels to create. Default is 6. Returns ------- y_ticklabels: array of y-axis tick mark labels of length ``n_ylabels``. """ # get the original y-axis tick mark labels and convert them to an array initial_yticklabels = np.asarray(get_yticklabels(ax=ax)) if initial_yticklabels.size < n_ylabels: # if there are less tick marks available than requested, # use the original tick mark labels y_ticklabels = initial_yticklabels else: # use the original tick marks to make a new array that # contains a subset of the original tick mark labels tick_idx = np.round( np.linspace(0, initial_yticklabels.size - 1, n_ylabels) ).astype(int) y_ticklabels = initial_yticklabels[tick_idx] return y_ticklabels
[docs]def set_x_axis_ticks_and_labels( jointgrid: Any, tmax: float, bin_max: float, n_xlabels: int = 4 ): """ Sets the x-axis tick mark locations and labels. Parameters ---------- jointgrid: a ``sns.axisgrid.JointGrid`` object. tmax: the maximum time to plot. bin_max: the maximum number of bins. n_xlabels: the number of x-axis tick marks to create. Default is 4. """ # retrieve the x-axis tick mark locations and labels xticks = get_x_axis_ticks(bin_max=bin_max, n_xlabels=n_xlabels) xticklabels = get_x_axis_tick_labels(max_time=tmax, n_xlabels=n_xlabels) # set the x-axis ticks and labels jointgrid.ax_joint.set_xticks(xticks) jointgrid.ax_joint.set_xticklabels(xticklabels, minor=False)
[docs]def set_y_axis_ticks_and_labels(jointgrid: Any, n_ylabels: int = 6): """ Sets the y-axis tick mark locations and labels. Parameters ---------- jointgrid: a ``sns.axisgrid.JointGrid`` object. n_ylabels: The number of y-axis tick mark labels to create. Default is 6. """ # retrieve the y-axis tick mark locations and labels yticks = get_y_axis_ticks(ax=jointgrid.ax_joint, n_ylabels=n_ylabels) yticklabels = get_y_axis_tick_labels(ax=jointgrid.ax_joint, n_ylabels=n_ylabels) # set the new y-axis tick locations and labels jointgrid.ax_joint.set_yticks(yticks) jointgrid.ax_joint.set_yticklabels(yticklabels, minor=False)
[docs]def remove_marginal_graph_ticks_and_labels(marg_x: Any, marg_y: Any): """ Removes the frame, tick marks, and tick mark labels for the marginal bar graphs. Parameters ---------- marg_x : a x-axis marginal bar graph object. marg_y : a y-axis marginal bar graph object. """ for marg in [marg_x, marg_y]: # turn the frame off for both bar graphs marg.axis("off") # remove all tick mark labels for both bar graphs marg.tick_params( axis="x", bottom=False, labelbottom=False, top=False, labeltop=False ) marg.tick_params( axis="y", left=False, labelleft=False, right=False, labelright=False )
[docs]def adjust_for_colorbar(jointgrid: Any, fig_right: float, cbar_x0: float): """ Makes various subplot location adjustments such that a colorbar can fit in the overal figure panel. Parameters ---------- jointgrid: a ``sns.axisgrid.JointGrid`` object. fig_right: the location to set for the right side of the heatmap figure. cbar_x0: the x-axis location of the colorbar. """ # adjust the subplot so the x/y tick labels are legible jointgrid.fig.subplots_adjust( left=0.1, bottom=0.15, top=0.9, hspace=0.03, wspace=0.04 ) # set the location of the right side of the figure jointgrid.fig.subplots_adjust(right=fig_right) # get the positions of the joint and marginal x axes pos_joint_ax = jointgrid.ax_joint.get_position() pos_marg_x_ax = jointgrid.ax_marg_x.get_position() # set the position and dimensions of the joint plot such that it fills # the space as if there was no colorbar jointgrid.ax_joint.set_position( [pos_joint_ax.x0, pos_joint_ax.y0, pos_marg_x_ax.width, pos_joint_ax.height] ) # set the position of the colorbar such that it is on the # right side of the horizontal bar graph, and set its dimensions jointgrid.fig.axes[-1].set_position( [cbar_x0, pos_joint_ax.y0, 0.9, pos_joint_ax.height] )
[docs]def plot_heatmap( report: darshan.DarshanReport, mod: str = "DXT_POSIX", ops: Sequence[str] = ["read", "write"], xbins: int = 200, submodule: Optional[str] = None, ) -> Any: """ Creates a heatmap with marginal bar graphs and colorbar. Parameters ---------- report: a ``darshan.DarshanReport``. mod: the DXT module to do analysis for (i.e. "DXT_POSIX" or "DXT_MPIIO"). Default is ``"DXT_POSIX"``. ops: a sequence of keys designating which Darshan operations to use for data aggregation. Default is ``["read", "write"]``. xbins: the number of x-axis bins to create; it has no effect when `mod` is `HEATMAP` submodule: when `mod` is `HEATMAP` this specifies the source of the runtime heatmap data, otherwise it has no effect Returns ------- jgrid: a ``sns.axisgrid.JointGrid`` object containing a heat map of IO data, marginal bar graphs, and a colorbar. Raises ------ NotImplementedError: if a DXT module is not input (i.e. "DXT_POSIX"). ValueError: if the input module is not in the ``DarshanReport``. """ if "DXT" not in mod and "HEATMAP" not in mod: raise NotImplementedError("Only DXT and HEATMAP modules are supported.") if mod not in report.modules: raise ValueError(f"Module {mod} not found in DarshanReport.") nprocs = report.metadata["job"]["nprocs"] tmax, runtime = determine_hmap_runtime(report=report) if "DXT" in mod: # aggregate the data according to the selected modules and operations agg_df = heatmap_handling.get_aggregate_data(report=report, mod=mod, ops=ops) # get the heatmap data array # NOTE: the darshan runtime does not collect empty DXT records, # so we are not guaranteed to have data for all time spans # as a result, we force the upper time bound for the heatmap data # to be the wallclock time hmap_df = heatmap_handling.get_heatmap_df(agg_df=agg_df, xbins=xbins, nprocs=nprocs, max_time=runtime) elif mod == "HEATMAP": hmap_df = report.heatmaps[submodule].to_df(ops=ops) # mirror the DXT approach to heatmaps by # adding all-zero rows for inactive ranks hmap_df = hmap_df.reindex(index=range(nprocs), fill_value=0.0) xbins = hmap_df.shape[1] # build the joint plot with marginal histograms jgrid = sns.jointplot(kind="hist", bins=[xbins, nprocs], space=0.05) # clear the x and y axis marginal graphs jgrid.ax_marg_x.cla() jgrid.ax_marg_y.cla() # create the label for the colorbar colorbar_label = f"Data (B): {', '.join(ops)}" colorbar_kws = {"label": colorbar_label} # create the heatmap object using the heatmap data, # and assign it to the jointplot main figure hmap = sns.heatmap( hmap_df, ax=jgrid.ax_joint, # choose a color map that is not white at any value cmap="YlOrRd", norm=LogNorm(), cbar_kws=colorbar_kws, ) # add text for x-axis bin count xbin_label = f"Time bins: {xbins}" plt.text( x=1.03, y=-0.04, s=xbin_label, fontsize=9, verticalalignment="top", horizontalalignment="left", transform=jgrid.ax_joint.transAxes, ) # make the heatmap border visible for _, spine in hmap.spines.items(): spine.set_visible(True) # if there is more than 1 process, # create the horizontal bar graph if nprocs > 1: jgrid.ax_marg_y.barh( y=np.arange(nprocs), width=hmap_df.sum(axis=1), align="edge", facecolor="black", lw=0.5, ) else: # if there is only 1 rank turn the axis off jgrid.ax_marg_y.axis("off") # create the vertical bar graph jgrid.ax_marg_x.bar( x=np.arange(xbins), height=hmap_df.sum(axis=0), facecolor="black", align="edge", ) # scale the x-axis to span the calculated run time xbin_max = xbins * (runtime / tmax) jgrid.ax_joint.set_xlim(0.0, xbin_max) # set the x and y tick locations and labels using the runtime set_x_axis_ticks_and_labels(jointgrid=jgrid, tmax=runtime, bin_max=xbin_max, n_xlabels=4) set_y_axis_ticks_and_labels(jointgrid=jgrid, n_ylabels=6) # cleanup the marginal bar graph ticks and tick labels remove_marginal_graph_ticks_and_labels( marg_x=jgrid.ax_marg_x, marg_y=jgrid.ax_marg_y ) # set the dimensions of the figure to 6.5" wide x 4.5" tall jgrid.fig.set_size_inches(6.5, 4.5) if nprocs > 1: # if there are multiple ranks we want to move the colorbar on the far # right side of the horizontal bar graph adjust_for_colorbar(jointgrid=jgrid, fig_right=0.84, cbar_x0=0.85) else: # if there is only 1 unique rank there is no horizontal bar graph, # so set the subplot dimensions to fill the space adjust_for_colorbar(jointgrid=jgrid, fig_right=0.92, cbar_x0=0.82) # invert the y-axis so rank values are increasing jgrid.ax_joint.invert_yaxis() # set the axis labels jgrid.ax_joint.set_xlabel("Time (s)") jgrid.ax_joint.set_ylabel("Rank") plt.close() return jgrid