# Source code for dynamo.plot.state_graph

import numpy as np

from ..tools.utils import update_dict
from .scatters import scatters, docstrings
from .utils import save_fig

# Derive a variant of the shared "scatters.parameters" documentation section
# with the `aggregate`, `kwargs` and `save_kwargs` entries removed. The
# `state_graph` docstring below pulls that variant in through the
# `%(scatters.parameters.no_aggregate|kwargs|save_kwargs)s` placeholder and
# documents `group` (used as `aggregate`) and `save_kwargs` itself.
# NOTE(review): `docstrings` is presumably a docrep DocstringProcessor
# exported by `.scatters` -- confirm there.
docstrings.delete_params("scatters.parameters", "aggregate", "kwargs", "save_kwargs")


def create_edge_patch(
    posA,
    posB,
    width=1,
    node_rad=0,
    connectionstyle="arc3, rad=0.25",
    facecolor="k",
    **kwargs
):
    """Create a single arrow patch running from `posA` to `posB`.

    Parameters
    ----------
        posA:
            Start point of the arrow, forwarded to `FancyArrowPatch` as `posA`.
        posB:
            End point of the arrow, forwarded to `FancyArrowPatch` as `posB`.
        width: number (default: `1`)
            Scale of the arrow tail; the tail width of the arrow style is
            `3 * width`. Fractional values are honoured.
        node_rad: number (default: `0`)
            Forwarded as both `shrinkA` and `shrinkB`, i.e. the arrow is
            shortened by this amount at each end.
        connectionstyle: `str` (default: `"arc3, rad=0.25"`)
            Connection style string forwarded to `FancyArrowPatch`.
        facecolor: (default: `"k"`)
            Fill color of the arrow.
        kwargs:
            Any additional keyword arguments forwarded to `FancyArrowPatch`.

    Returns
    -------
        A `matplotlib.patches.FancyArrowPatch` that has not yet been added to
        any axes.
    """
    import matplotlib.patches as pat

    # "%g" instead of "%d": callers pass fractional widths (e.g. a transition
    # probability times a scale) and "%d" silently truncated them, collapsing
    # every tail narrower than one unit to a width of 0. Integer inputs
    # produce exactly the same style string as before.
    style = "simple,head_length=%g,head_width=%g,tail_width=%g" % (10, 10, 3 * width)
    return pat.FancyArrowPatch(
        posA=posA,
        posB=posB,
        arrowstyle=style,
        connectionstyle=connectionstyle,
        facecolor=facecolor,
        shrinkA=node_rad,
        shrinkB=node_rad,
        **kwargs
    )


def create_edge_patches_from_markov_chain(
    P,
    X,
    width=3,
    node_rad=0,
    tol=1e-7,
    connectionstyle="arc3, rad=0.25",
    facecolor="k",
    alpha=0.8,
    **kwargs
):
    """Turn a transition matrix into a list of arrow patches.

    Every ordered pair `(src, dst)` with `P[src, dst] > tol` yields one arrow
    from `X[src]` to `X[dst]`, built by `create_edge_patch`. Pairs are visited
    row by row, and both indices run over `range(P.shape[0])`.

    Parameters
    ----------
        P:
            Transition matrix, indexed as `P[src, dst]`.
        X:
            Node positions, indexed as `X[src]`.
        width: number (default: `3`)
            Each arrow is created with `width=P[src, dst] * width`.
        node_rad: number (default: `0`)
            Passed through to `create_edge_patch`.
        tol: `float` (default: `1e-7`)
            Entries of `P` that are not strictly greater than this are skipped.
        connectionstyle: `str` (default: `"arc3, rad=0.25"`)
            Passed through to `create_edge_patch`.
        facecolor: (default: `"k"`)
            Passed through to `create_edge_patch`.
        alpha: `float` (default: `0.8`)
            Base opacity; each arrow uses `alpha * min(2 * P[src, dst], 1)`.
        kwargs:
            Extra keyword arguments passed through to `create_edge_patch`.

    Returns
    -------
        A list with one arrow patch per retained transition.
    """
    n_states = P.shape[0]
    return [
        create_edge_patch(
            X[src],
            X[dst],
            width=P[src, dst] * width,
            node_rad=node_rad,
            connectionstyle=connectionstyle,
            facecolor=facecolor,
            alpha=alpha * min(2 * P[src, dst], 1),
            **kwargs
        )
        for src in range(n_states)
        for dst in range(n_states)
        if P[src, dst] > tol
    ]


@docstrings.with_indent(4)
def state_graph(
    adata,
    group,
    basis="umap",
    x=0,
    y=1,
    color='ntr',
    layer="X",
    highlights=None,
    labels=None,
    values=None,
    theme=None,
    cmap=None,
    color_key=None,
    color_key_cmap=None,
    background=None,
    ncols=1,
    pointsize=None,
    figsize=(6, 4),
    show_legend=True,
    use_smoothed=True,
    show_arrowed_spines=True,
    ax=None,
    sort='raw',
    frontier=False,
    save_show_or_return="show",
    save_kwargs={},
    s_kwargs_dict={},
    **kwargs
):
    """Plot a summarized cell type (state) transition graph.

    This function tries to create a model that summarizes the possible cell type transitions based on the
    reconstructed vector field function.

    Parameters
    ----------
        group: `str`
            The column in adata.obs that will be used to aggregate data points for the purpose of creating a cell type
            transition model. The transition graph is read from `adata.uns[group + "_graph"]["group_graph"]`; when
            that key is absent the legacy key `"Cell type annotation_graph"` is used instead.
        %(scatters.parameters.no_aggregate|kwargs|save_kwargs)s
        save_kwargs: `dict` (default: `{}`)
            A dictionary that will passed to the save_fig function. By default it is an empty dictionary and the
            save_fig function will use the {"path": None, "prefix": 'state_graph', "dpi": None, "ext": 'pdf',
            "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a
            dictionary that properly modify those keys according to your needs.
        s_kwargs_dict: `dict` (default: {})
            The dictionary of the scatter arguments. It is copied before use; its `"s"` entry is always replaced by
            the group sizes.

    Returns
    -------
        Plot a model of cell fate transition that summarizes the possible lineage commitments between different
        cell types. Nothing is returned when the figure is shown or saved; when `save_show_or_return` is `"return"`
        the tuple `(axes_list, color_list, font_color)` obtained from `scatters` is returned.
    """

    import matplotlib.pyplot as plt
    from matplotlib import rcParams
    from matplotlib.colors import to_hex

    aggregate = group

    points = adata.obsm["X_" + basis][:, [x, y]]
    groups = adata.obs[group]
    # `Series.unique()` yields an ndarray for plain columns and a Categorical
    # for categorical ones; `list(...)` works for both, `.to_list()` does not.
    uniq_grp = list(groups.unique())
    grp_size = groups.value_counts().values

    # Work on a copy: updating the argument in place would leak the "s" entry
    # into the shared `{}` default (and into the caller's dictionary).
    s_kwargs_dict = dict(s_kwargs_dict)
    s_kwargs_dict["s"] = grp_size

    # Prefer the graph that belongs to `group`; keep the previously hard-coded
    # key as a fallback so existing callers continue to work.
    graph_key = group + "_graph"
    if graph_key not in adata.uns:
        graph_key = "Cell type annotation_graph"

    # Copy before pruning/normalising so plotting never modifies `adata.uns`
    # (in-place edits made repeated calls produce different graphs).
    Pl = np.array(adata.uns[graph_key]["group_graph"], dtype=float)
    # Keep only the dominant direction of every pair of opposite transitions.
    Pl[Pl - Pl.T < 0] = 0
    # Row-normalise; rows without any remaining transition stay all-zero
    # instead of turning into NaN.
    row_sums = Pl.sum(1)[:, None]
    Pl = np.divide(Pl, row_sums, out=np.zeros_like(Pl), where=row_sums > 0)

    # One node per group, placed at the median embedding position of its cells.
    group_median = np.zeros((len(uniq_grp), 2))
    for i, cur_grp in enumerate(uniq_grp):
        group_median[i, :] = np.nanmedian(points[np.where(groups == cur_grp)[0], :2], 0)

    if background is None:
        _background = rcParams.get("figure.facecolor")
        background = to_hex(_background) if isinstance(_background, tuple) else _background
    else:
        # Without this branch `_background` is unbound whenever the caller
        # supplies a background color.
        _background = background

    plt.figure(facecolor=_background)
    axes_list, color_list, font_color = scatters(
        adata=adata,
        basis=basis,
        x=x,
        y=y,
        color=color,
        layer=layer,
        highlights=highlights,
        labels=labels,
        values=values,
        theme=theme,
        cmap=cmap,
        color_key=color_key,
        color_key_cmap=color_key_cmap,
        background=background,
        ncols=ncols,
        pointsize=pointsize,
        figsize=figsize,
        show_legend=show_legend,
        use_smoothed=use_smoothed,
        aggregate=aggregate,
        show_arrowed_spines=show_arrowed_spines,
        ax=ax,
        sort=sort,
        save_show_or_return='return',
        frontier=frontier,
        **s_kwargs_dict,
        return_all=True,
    )

    target_axes = axes_list if isinstance(axes_list, list) else [axes_list]
    for cur_ax in target_axes:
        # A matplotlib artist can only be attached to a single Axes, so a
        # fresh set of arrows is built for every panel.
        arrows = create_edge_patches_from_markov_chain(
            Pl, group_median, tol=0.01, node_rad=15
        )
        for arrow in arrows:
            cur_ax.add_patch(arrow)
        cur_ax.set_facecolor(background)

    plt.axis("off")

    # The figure is shown only in the "show" branch below; showing it before
    # this dispatch also displayed it in "save" and "return" mode.
    if save_show_or_return == "save":
        s_kwargs = {"path": None, "prefix": 'state_graph', "dpi": None,
                    "ext": 'pdf', "transparent": True, "close": True, "verbose": True}
        s_kwargs = update_dict(s_kwargs, save_kwargs)

        save_fig(**s_kwargs)
    elif save_show_or_return == "show":
        if show_legend:
            plt.subplots_adjust(right=0.85)

        plt.tight_layout()
        plt.show()
    elif save_show_or_return == "return":
        return axes_list, color_list, font_color