from typing import Any, Dict, List, Optional, Tuple, Union
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import numpy as np
import pandas as pd
from anndata import AnnData
from matplotlib.axes import Axes
from ..tools.utils import update_dict
from .scatters import docstrings, scatters
from .utils import save_fig
# Ask the shared `docstrings` object (imported from `.scatters`) to drop the `aggregate`, `kwargs` and `save_kwargs`
# entries from its "scatters.parameters" section.
# NOTE(review): presumably `docstrings` is a docrep-style DocstringProcessor and the trimmed section is meant to be
# reused by docstrings in this module — confirm against `.scatters`.
docstrings.delete_params("scatters.parameters", "aggregate", "kwargs", "save_kwargs")
def create_edge_patch(posA, posB, width=1, node_rad=0, connectionstyle="arc3, rad=0.25", facecolor="k", **kwargs):
    """Build a single arrow patch pointing from `posA` to `posB`.

    Args:
        posA: the start position of the arrow, handed to `FancyArrowPatch` as `posA`.
        posB: the end position of the arrow, handed to `FancyArrowPatch` as `posB`.
        width: multiplier of the arrow tail width. The tail width is `3 * width` and is written into the arrow style
            with an integer (`%d`) format, so fractional widths are truncated. Defaults to 1.
        node_rad: shrink amount applied to both ends of the arrow (`shrinkA` and `shrinkB`). Defaults to 0.
        connectionstyle: the connection style string handed to `FancyArrowPatch`. Defaults to "arc3, rad=0.25".
        facecolor: the face color of the arrow. Defaults to "k".
        **kwargs: any other keyword arguments forwarded unchanged to `FancyArrowPatch`.

    Returns:
        The newly created `matplotlib.patches.FancyArrowPatch`; it is not attached to any axes.
    """
    from matplotlib.patches import FancyArrowPatch

    # The arrow head has a fixed size; only the tail scales with `width`.
    head_size = 10
    tail_width = 3 * width
    arrow_style = "simple,head_length=%d,head_width=%d,tail_width=%d" % (head_size, head_size, tail_width)

    patch = FancyArrowPatch(
        posA=posA,
        posB=posB,
        arrowstyle=arrow_style,
        connectionstyle=connectionstyle,
        facecolor=facecolor,
        shrinkA=node_rad,
        shrinkB=node_rad,
        **kwargs,
    )
    return patch
def _lookup_edge_attr(value, i, j):
    """Return the entry of a per-edge attribute table (DataFrame or 2D indexable) for the edge from i to j."""
    if isinstance(value, pd.DataFrame):
        # DataFrames are addressed by position so that their labels do not matter.
        return value.iloc[i, j]
    return value[i, j]


def create_edge_patches_from_markov_chain(
    P,
    X,
    width=3,
    node_rad=0,
    tol=1e-7,
    connectionstyle="arc3, rad=0.25",
    facecolor="k",
    edgecolor="k",
    alpha=0.8,
    **kwargs
):
    """Create edge patches from a markov chain transition matrix.

    If P[i, j] > tol, an arrow is created from node i to j.

    Args:
        P: the transition matrix, indexed as `P[i, j]`. Both loops run over `P.shape[0]`, so it is treated as
            square.
        X: the node positions; `X[i]` is the position of node i.
        width: the scale of the arrow width; the arrow from i to j is created with width `P[i, j] * width`.
            Defaults to 3.
        node_rad: the shrink amount applied to both ends of every arrow. Defaults to 0.
        tol: entries of `P` that are not strictly greater than this value produce no arrow. Defaults to 1e-7.
        connectionstyle: the connection style handed to every arrow. Defaults to "arc3, rad=0.25".
        facecolor: either one color string shared by all arrows, or a DataFrame / 2D indexable holding one color per
            (i, j) edge. Defaults to "k".
        edgecolor: same convention as `facecolor`, for the outline of the arrows. Defaults to "k".
        alpha: either a scalar (Python or NumPy number), in which case the arrow from i to j gets
            `alpha * min(2 * P[i, j], 1)`, or a DataFrame / 2D indexable holding one alpha per (i, j) edge that is
            used as is. Defaults to 0.8.
        **kwargs: any other keyword arguments forwarded to `create_edge_patch` for every arrow.

    Returns:
        A list with one arrow patch per retained edge, ordered by i and then by j.
    """
    shared_facecolor = isinstance(facecolor, str)
    shared_edgecolor = isinstance(edgecolor, str)
    # Any real scalar counts as a global alpha. The former exact `float` check sent ints and NumPy scalars to the
    # per-edge lookup, where indexing them failed.
    scalar_alpha = isinstance(alpha, (int, float, np.integer, np.floating))

    n_nodes = P.shape[0]
    arrows = []
    for i in range(n_nodes):
        for j in range(n_nodes):
            prob = P[i, j]
            # Written as `not (prob > tol)` so that NaN entries are skipped exactly like small ones.
            if not (prob > tol):
                continue

            fc = facecolor if shared_facecolor else _lookup_edge_attr(facecolor, i, j)
            ec = edgecolor if shared_edgecolor else _lookup_edge_attr(edgecolor, i, j)
            if scalar_alpha:
                # Fade weak transitions; the factor saturates once the probability reaches 0.5.
                ac = alpha * min(2 * prob, 1)
            else:
                ac = _lookup_edge_attr(alpha, i, j)

            arrows.append(
                create_edge_patch(
                    X[i],
                    X[j],
                    width=prob * width,
                    node_rad=node_rad,
                    connectionstyle=connectionstyle,
                    facecolor=fc,
                    edgecolor=ec,
                    alpha=ac,
                    **kwargs,
                )
            )
    return arrows
@docstrings.with_indent(4)
def state_graph(
    adata: AnnData,
    group: Optional[str] = None,
    transition_threshold: float = 0.001,
    keep_only_one_direction: bool = True,
    edge_scale: float = 1,
    state_graph: Optional[np.ndarray] = None,
    edgecolor: Union[None, np.ndarray, pd.DataFrame] = None,
    facecolor: Union[None, np.ndarray, pd.DataFrame] = None,
    graph_alpha: Union[None, np.ndarray, pd.DataFrame] = None,
    basis: str = "umap",
    x: int = 0,
    y: int = 1,
    color: str = "ntr",
    layer: str = "X",
    highlights: Optional[list] = None,
    labels: Optional[list] = None,
    values: Optional[list] = None,
    theme: Optional[
        Literal[
            "blue",
            "red",
            "green",
            "inferno",
            "fire",
            "viridis",
            "darkblue",
            "darkred",
            "darkgreen",
        ]
    ] = None,
    cmap: Optional[str] = None,
    color_key: Union[Dict[str, str], List[str], None] = None,
    color_key_cmap: Optional[str] = None,
    background: Optional[str] = None,
    ncols: int = 4,
    pointsize: Optional[float] = None,
    figsize: Tuple[float, float] = (6, 4),
    show_legend: bool = True,
    use_smoothed: bool = True,
    show_arrowed_spines: bool = False,
    ax: Optional[Axes] = None,
    sort: Literal["raw", "abs", "neg"] = "raw",
    frontier: bool = False,
    save_show_or_return: Literal["save", "show", "return"] = "show",
    save_kwargs: Dict[str, Any] = {},
    s_kwargs_dict: Dict[str, Any] = {"alpha": 1},
    **kwargs
) -> Union[
    Tuple[Axes, List[str], Literal["white", "black"]],
    Tuple[List[Axes], List[str], Literal["white", "black"]],
    None,
]:
    """Plot a summarized cell type (state) transition graph. This function tries to create a model that summarizes the
    possible cell type transitions based on the reconstructed vector field function.

    The cells are first drawn with `scatters`; an arrow is then added between the median positions of every pair of
    groups whose (normalized) transition value exceeds 0.01.

    Args:
        adata: an AnnData object.
        group: the column in adata.obs that will be used to aggregate data points for the purpose of creating a cell
            type transition model. Although the default is None, a valid column name is required because the column
            is always read from adata.obs. Defaults to None.
        transition_threshold: the threshold of cell fate transition. Transition will be ignored if below this threshold.
            Only applied when `state_graph` is None. Defaults to 0.001.
        keep_only_one_direction: whether to only keep the higher transition between two cell type. That is if the
            transition rate from A to B is higher than B to A, only edge from A to B will be plotted. Only applied
            when `state_graph` is None. Defaults to True.
        edge_scale: the scaler that can be used to scale the edge width of drawn transition graph. The row-normalized
            graph is divided by this value. Only applied when `state_graph` is None. Defaults to 1.
        state_graph: the lumped transition graph between cell states (e.g. cell clusters or types). When None, the
            graph stored in `adata.uns[group + "_graph"]["group_graph"]` is used; a user supplied graph is drawn as
            is. Defaults to None.
        edgecolor: the edge color of the arcs that corresponds to the lumped transition graph between cell states.
            None means black ("k"). Defaults to None.
        facecolor: the face color of the arcs that corresponds to the lumped transition graph between cell states.
            None means black ("k"). Defaults to None.
        graph_alpha: the alpha of the arcs that corresponds to the lumped transition graph between cell states. None
            means 0.8. Defaults to None.
        basis: the reduced dimension stored in adata.obsm. The specific basis key will be constructed in the following
            priority if exists: 1) specific layer input + basis 2) X_ + basis 3) basis. E.g. if basis is PCA, `scatters`
            is going to look for 1) if specific layer is spliced, `spliced_pca` 2) `X_pca` (dynamo convention) 3) `pca`.
            Note that the group medians used to place the arrows are always read from `adata.obsm["X_" + basis]`.
            Defaults to "umap".
        x: the column index of the low dimensional embedding for the x-axis. Defaults to 0.
        y: the column index of the low dimensional embedding for the y-axis. Defaults to 1.
        color: any column names or gene expression, etc. that will be used for coloring cells. Defaults to "ntr".
        layer: the layer of data to use for the scatter plot. Defaults to "X".
        highlights: the color group that will be highlighted. If highlights is a list of lists, each list is relate to
            each color element. Defaults to None.
        labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for
            coloring the points in the plot according to their label. Note that this option is mutually exclusive to the
            `values` option. Defaults to None.
        values: an array of values (assumed float or continuous), one for each sample. This will be used for coloring
            the points in the plot according to a colorscale associated to the total range of values. Note that this
            option is mutually exclusive to the `labels` option. Defaults to None.
        theme: A color theme to use for plotting. A small set of predefined themes are provided which have relatively
            good aesthetics. Available themes are: {'blue', 'red', 'green', 'inferno', 'fire', 'viridis', 'darkblue',
            'darkred', 'darkgreen'}. Defaults to None.
        cmap: The name of a matplotlib colormap to use for coloring or shading points. If no labels or values are passed
            this will be used for shading points according to density (largely only of relevance for very large
            datasets). If values are passed this will be used for shading according the value. Note that if theme is
            passed then this value will be overridden by the corresponding option of the theme. Defaults to None.
        color_key: the method to assign colors to categoricals. This can either be an explicit dict mapping labels to
            colors (as strings of form '#RRGGBB'), or an array like object providing one color for each distinct
            category being provided in `labels`. Either way this mapping will be used to color points according to the
            label. Note that if theme is passed then this value will be overridden by the corresponding option of the
            theme. Defaults to None.
        color_key_cmap: the name of a matplotlib colormap to use for categorical coloring. If an explicit `color_key` is
            not given a color mapping for categories can be generated from the label list and selecting a matching list
            of colors from the given colormap. Note that if theme is passed then this value will be overridden by the
            corresponding option of the theme. Defaults to None.
        background: the color of the background. Usually this will be either 'white' or 'black', but any color name will
            work. Ideally one wants to match this appropriately to the colors being used for points etc. This is one of
            the things that themes handle for you. Note that if theme is passed then this value will be overridden by
            the corresponding option of the theme. When None, matplotlib's "figure.facecolor" setting is used.
            Defaults to None.
        ncols: the number of columns for the figure. Defaults to 4.
        pointsize: the scale of the point size. Actual point cell size is calculated as
            `500.0 / np.sqrt(adata.shape[0]) * pointsize`. Defaults to None.
        figsize: the width and height of a figure. Defaults to (6, 4).
        show_legend: whether to display a legend of the labels. Defaults to True.
        use_smoothed: whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to
            True.
        show_arrowed_spines: whether to show a pair of arrowed spines representing the basis of the scatter is currently
            using. Defaults to False.
        ax: the matplotlib axes object where new plots will be added to. Only applicable to drawing a single component.
            Defaults to None.
        sort: the method to reorder data so that high values points will be on top of background points. Can be one of
            {'raw', 'abs', 'neg'}, i.e. sorted by raw data, sort by absolute values or sort by negative values. Defaults
            to "raw".
        frontier: whether to add the frontier. Scatter plots can be enhanced by using transparency (alpha) in order to
            show area of high density and multiple scatter plots can be used to delineate a frontier. See matplotlib
            tips & tricks cheatsheet (https://github.com/matplotlib/cheatsheets). Originally inspired by figures from
            scEU-seq paper: https://science.sciencemag.org/content/367/6482/1151. If `contour` is set to be True,
            `frontier` will be ignored as `contour` also add an outlier for data points. Defaults to False.
        save_show_or_return: whether to save, show, or return the generated figure. Besides "save", "show" and
            "return", the values "both" (save and show) and "all" (save, show and return) are also handled. Defaults
            to "show".
        save_kwargs: A dictionary that will be passed to the save_fig function. By default, it is an empty dictionary
            and the save_fig function will use the {"path": None, "prefix": 'state_graph', "dpi": None, "ext": 'pdf',
            "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you can provide a
            dictionary that properly modify those keys according to your needs. Defaults to {}.
        s_kwargs_dict: any other kwargs that would be passed to `dynamo.pl.scatters`. Defaults to {"alpha": 1}.
        **kwargs: accepted for compatibility; currently not used by this function.

    Returns:
        None would be returned by default. If `save_show_or_return` is set to be `return` (or `all`), the matplotlib
        axes object of the generated plots, the list of colors used and the font color would be returned.
    """

    import matplotlib.pyplot as plt
    from matplotlib import rcParams
    from matplotlib.colors import to_hex

    aggregate = group

    points = adata.obsm["X_" + basis][:, [x, y]]
    groups = adata.obs[group]
    unique_group_obs = groups.unique()
    if isinstance(unique_group_obs, np.ndarray):
        uniq_grp = unique_group_obs.tolist()
    elif isinstance(unique_group_obs, pd.Series):
        uniq_grp = unique_group_obs.to_list()
    else:
        uniq_grp = list(unique_group_obs)
    group_median = np.zeros((len(uniq_grp), 2))

    if state_graph is None:
        # Work on a copy: the thresholding and normalization below are done in place and must not alter the graph
        # stored in adata.uns (otherwise every call would modify the data and repeated calls would compound).
        Pl = adata.uns[group + "_graph"]["group_graph"].copy()
        if keep_only_one_direction:
            # Zero the weaker direction of every pair of opposite transitions.
            Pl[Pl - Pl.T < 0] = 0
        if transition_threshold is not None:
            Pl[Pl < transition_threshold] = 0

        # Row-normalize, then rescale by the user supplied edge scale.
        Pl /= Pl.sum(1)[:, None] * edge_scale
    else:
        Pl = state_graph

    # Each group is placed at the median embedding coordinate of its cells.
    # NOTE(review): this assumes the rows/columns of the graph follow the order of `uniq_grp` (order returned by
    # `unique()`) — confirm against the code that builds `group_graph`.
    for i, cur_grp in enumerate(uniq_grp):
        group_median[i, :] = np.nanmedian(points[np.where(groups == cur_grp)[0], :2], 0)

    if background is None:
        _background = rcParams.get("figure.facecolor")
        background = to_hex(_background) if isinstance(_background, tuple) else _background
    else:
        # Keep the figure color defined when the caller chooses the background explicitly.
        _background = background

    plt.figure(facecolor=_background)
    axes_list, color_list, font_color = scatters(
        adata=adata,
        basis=basis,
        x=x,
        y=y,
        color=color,
        layer=layer,
        highlights=highlights,
        labels=labels,
        values=values,
        theme=theme,
        cmap=cmap,
        color_key=color_key,
        color_key_cmap=color_key_cmap,
        background=background,
        ncols=ncols,
        pointsize=pointsize,
        figsize=figsize,
        show_legend=show_legend,
        use_smoothed=use_smoothed,
        aggregate=aggregate,
        show_arrowed_spines=show_arrowed_spines,
        ax=ax,
        sort=sort,
        save_show_or_return="return",
        frontier=frontier,
        **s_kwargs_dict,
        return_all=True,
    )

    edgecolor = "k" if edgecolor is None else edgecolor
    facecolor = "k" if facecolor is None else facecolor
    graph_alpha = 0.8 if graph_alpha is None else graph_alpha

    drawn_axes = axes_list if isinstance(axes_list, list) else [axes_list]
    for cur_ax in drawn_axes:
        # A matplotlib artist can only belong to one axes, so every axes gets its own set of arrow patches.
        arrows = create_edge_patches_from_markov_chain(
            Pl, group_median, edgecolor=edgecolor, facecolor=facecolor, alpha=graph_alpha, tol=0.01, node_rad=15
        )
        for arrow in arrows:
            cur_ax.add_patch(arrow)
        cur_ax.set_facecolor(background)
        # Hide the axis on the axes that was actually drawn on rather than on whatever pyplot considers current.
        cur_ax.axis("off")

    if save_show_or_return in ["save", "both", "all"]:
        s_kwargs = {
            "path": None,
            "prefix": "state_graph",
            "dpi": None,
            "ext": "pdf",
            "transparent": True,
            "close": True,
            "verbose": True,
        }
        s_kwargs = update_dict(s_kwargs, save_kwargs)

        if save_show_or_return in ["both", "all"]:
            # The figure is still needed for showing / returning, so it must not be closed after saving.
            s_kwargs["close"] = False

        save_fig(**s_kwargs)
    if save_show_or_return in ["show", "both", "all"]:
        if show_legend:
            plt.subplots_adjust(right=0.85)
        plt.tight_layout()
        plt.show()
    if save_show_or_return in ["return", "all"]:
        return axes_list, color_list, font_color