# Source code for dynamo.plot.markers

import warnings

import numpy as np
import pandas as pd
from scipy.sparse import issparse

from ..configuration import _themes, reset_rcParams, set_figure_params
from ..tools.utils import get_mapper, update_dict
from .utils import save_fig


def bubble(
    adata,
    genes,
    group,
    gene_order=None,
    group_order=None,
    layer=None,
    theme=None,
    cmap=None,
    color_key=None,
    color_key_cmap="Spectral",
    background="white",
    pointsize=None,
    vmin=0,
    vmax=100,
    sym_c=False,
    alpha=0.8,
    edgecolor=None,
    linewidth=0,
    type="violin",
    sort="diagnoal",
    transpose=False,
    rotate_xlabel="horizontal",
    rotate_ylabel="horizontal",
    figsize=None,
    save_show_or_return="show",
    save_kwargs={},
    **kwargs,
):
    """Bubble plots generalized to velocity, acceleration, curvature.

    One panel is drawn per gene, either as a violin plot or as a dot plot of the per-group
    mean (color) and fraction of non-zero cells (size). This function is loosely based on
    https://github.com/QuKunLab/COVID-19/blob/master/step3_plot_umap_and_marker_gene_expression.ipynb

    Parameters
    ----------
        adata: :class:`~anndata.AnnData`
            an Annodata object
        genes: `list`
            The gene list, i.e. marker genes or top acceleration, curvature genes, etc. Only names
            present in `adata.var_names` are kept.
        group: `str`
            The column key in `adata.obs` that will be used to group cells.
        gene_order: `None` or `list` (default: `None`)
            The order of genes in the resulting plot. Must be a subset of the matched genes.
        group_order: `None` or `list` (default: `None`)
            The order of cell groups in the resulting plot. Must be a subset of
            `adata.obs[group].unique()`.
        layer: `None` or `str` (default: `None`)
            The layer of data to use. When `None`, the total layer is used for labeling data and the
            spliced layer otherwise, based on the flags stored in `adata.uns["pp"]`.
        theme: string (optional, default None)
            A color theme name; it must be a key of the package's theme table. Available themes are:
               * 'blue'
               * 'red'
               * 'green'
               * 'inferno'
               * 'fire'
               * 'viridis'
               * 'darkblue'
               * 'darkred'
               * 'darkgreen'
        cmap: string (optional, default None)
            The name of a matplotlib colormap used to color the dots by mean value when `type` is
            `dot`. Falls back to 'viridis' when `None`.
        color_key: dict or array, shape (n_categories) (optional, default None)
            Colors for the cell groups, used as the violin palette. When `None`, colors are sampled
            evenly from `color_key_cmap`.
        color_key_cmap: string (optional, default 'Spectral')
            The name of a matplotlib colormap used to generate `color_key` when it is not given.
        background: string or None (optional, default 'white')
            The figure face color. When `None`, matplotlib's current `figure.facecolor` is used.
        pointsize: `None` or `float` (default: None)
            The scale of the dot size. The base size is `16000.0 / np.sqrt(adata.shape[0])` when
            `None`, otherwise `16000.0 / (n_genes * n_groups) * pointsize`.
        vmin: `float` (default: `0`)
            The lower percentile (0-100) of each gene's values used as the axis limit.
        vmax: `float` (default: `100`)
            The upper percentile (0-100) of each gene's values used as the axis limit.
        sym_c: `bool` (default: `False`)
            Whether to make the limits symmetric around zero, normally used for velocity, jacobian,
            curl, divergence or other data with both positive and negative values.
        alpha: `float` (default: `0.8`)
            alpha value of the plot
        edgecolor: `str` or `None` (default: `None`)
            The color of the edge of the dots when `type` is `dot`.
        linewidth: `float` (default: `0`)
            The width of the edge of the dots when `type` is `dot`.
        type: `str` (default: `violin`)
            The type of the bubble plot, one of `{'violin', 'dot'}`.
        sort: `str` (default: `diagnoal`)
            The method for sorting genes. Not implemented; the value is currently ignored.
        transpose: `bool` (default: `False`)
            Whether to transpose the row/column of the resulting bubble plot. Gene and cell types are
            on x/y-axis by default.
        rotate_xlabel: `float` or `str` (default: `horizontal`)
            The angle to rotate the x-label.
        rotate_ylabel: `float` or `str` (default: `horizontal`)
            The angle to rotate the y-label.
        figsize: `None` or `[float, float]` (default: None)
            The width and height of a figure. Reversed when `transpose` is True.
        save_show_or_return: `str` {'save', 'show', 'return'} (default: `show`)
            Whether to save, show or return the figure.
        save_kwargs: `dict` (default: `{}`)
            A dictionary that will be passed to the save_fig function. By default the save_fig
            function will use {"path": None, "prefix": 'violin', "dpi": None, "ext": 'pdf',
            "transparent": True, "close": True, "verbose": True} as its parameters; keys given here
            override those defaults.
        kwargs:
            Additional arguments passed to sns.violinplot (only used when `type` is `violin`).

    Returns
    -------
        `(fig, axes)` when `save_show_or_return` is `return`, where `axes` is a 1-D array with one
        Axes per gene; otherwise nothing is returned.

    Raises
    ------
        ValueError
            If `type` is not `violin` or `dot`, `group` is not a column of `adata.obs`, no requested
            gene is found in `adata.var_names`, or `gene_order` / `group_order` contain unknown names.
    """

    # The `type` parameter shadows the builtin of the same name, so the builtin must never be
    # called inside this function; keep the value under an unambiguous local name instead.
    plot_type = type
    if plot_type not in ("violin", "dot"):
        raise ValueError(f"argument type {plot_type} must be one of 'violin' or 'dot'.")

    # ---- validate the inputs first so bad arguments fail before any plotting work ----
    if group not in adata.obs_keys():
        raise ValueError(f"argument group {group} is not a column name in `adata.obs`")

    requested_genes = list(genes)
    genes = adata.var_names.intersection(set(requested_genes)).to_list()
    if not genes:
        # report what the caller asked for, not the (empty) intersection
        raise ValueError(
            f"names from argument genes {requested_genes} don't match any genes from `adata.var_names`."
        )

    uniq_groups = adata.obs[group].unique()
    if group_order is None:
        clusters = uniq_groups
    else:
        if not set(group_order).issubset(uniq_groups):
            raise ValueError(
                f"names from argument group_order {group_order} are not subsets of "
                f"`adata.obs[group].unique()`."
            )
        clusters = group_order

    if gene_order is not None:
        if not set(gene_order).issubset(genes):
            raise ValueError(
                f"names from argument gene_order {gene_order} is not a subset of "
                f"`adata.var_names.intersection(set(genes)).to_list()`."
            )
        genes = gene_order

    import matplotlib.pyplot as plt
    import seaborn as sns
    from matplotlib import rcParams
    from matplotlib.colors import to_hex

    # ---- background / theme ----
    if background is None:
        _background = rcParams.get("figure.facecolor")
        _background = to_hex(_background) if isinstance(_background, tuple) else _background
    else:
        _background = background

    if theme is None:
        _theme_ = "glasbey_dark" if _background in ["#ffffff", "black"] else "glasbey_white"
    else:
        _theme_ = theme
    # Looked up so that an unknown theme name still fails here; the resolved colormap is not
    # consumed by the drawing code below, which uses `cmap` directly.
    _cmap = _themes[_theme_]["cmap"] if cmap is None else cmap

    # ---- pick the default layer from the preprocessing flags ----
    if layer is None:
        mapper = get_mapper()
        pp_flags = adata.uns["pp"]
        default_key = "X_total" if (pp_flags["splicing_labeling"] or pp_flags["has_labeling"]) else "X_spliced"
        layer = mapper[default_key] if mapper[default_key] in adata.layers else default_key

    # ---- genes x cells expression table ----
    cells_df = adata.obs.get(group)
    expr = adata[:, genes].layers[layer]
    expr = expr.toarray() if issparse(expr) else expr
    gene_df = pd.DataFrame(expr.T, index=genes, columns=adata.obs_names)

    # Per-gene axis limits as plain arrays so that they can be indexed by panel position.
    xmin = gene_df.quantile(vmin / 100, axis=1).to_numpy()
    xmax = gene_df.quantile(vmax / 100, axis=1).to_numpy()
    if sym_c:
        # symmetric limits: +/- max(|lower|, upper), ignoring NaN where possible
        bounds = np.fmax(np.abs(xmin), xmax)
        xmin, xmax = -bounds, bounds

    point_size = (
        16000.0 / np.sqrt(adata.shape[0])
        if pointsize is None
        else 16000.0 / (len(genes) * len(clusters)) * pointsize
    )

    if color_key is None:
        num_labels = np.unique(clusters).shape[0]
        color_key = plt.get_cmap(color_key_cmap)(np.linspace(0, 1, num_labels))

    if figsize is None:
        # NOTE(review): the non-transposed height scales with the number of genes rather than the
        # number of groups; kept as-is, confirm whether `len(clusters)` was intended.
        width = 6 * len(genes) / 14 if transpose else 9 * len(genes) / 14
        height = 4.5 * len(clusters) / 14 if transpose else 4.5 * len(genes) / 14
        figsize = (height, width) if transpose else (width, height)
    else:
        figsize = figsize[::-1] if transpose else figsize

    n_panels = len(genes)
    fig, axes = plt.subplots(
        n_panels if transpose else 1,
        1 if transpose else n_panels,
        figsize=figsize,
        facecolor=background,
        squeeze=False,  # always get an array back, even for a single gene
    )
    axes = axes.ravel()
    fig.subplots_adjust(hspace=0, wspace=0)

    clusters_vec = cells_df.loc[gene_df.columns.values].values
    cluster_labels = list(map(str, np.array(clusters)))
    last_panel = n_panels - 1

    for igene, gene in enumerate(genes):
        ax = axes[igene]
        cur_gene_df = pd.DataFrame({gene: gene_df.iloc[igene].to_numpy(), "clusters_": clusters_vec})
        cur_gene_df = cur_gene_df.loc[cur_gene_df["clusters_"].isin(clusters)]

        if plot_type == "violin":
            sns.violinplot(
                data=cur_gene_df,
                x="clusters_" if transpose else gene,
                y=gene if transpose else "clusters_",
                orient="v" if transpose else "h",
                order=clusters,
                linewidth=None,
                palette=color_key,
                inner="box",
                scale="width",
                cut=0,
                ax=ax,
                alpha=alpha,
                **kwargs,
            )
            # the value axis is clipped to the gene's percentile range and labelled with the gene
            if transpose:
                ax.set_ylim(xmin[igene], xmax[igene])
                ax.set_yticks([])
                ax.set_ylabel(gene, rotation=rotate_ylabel, ha="right", va="center")
            else:
                ax.set_xlim(xmin[igene], xmax[igene])
                ax.set_xticks([])
                ax.set_xlabel(gene, rotation=rotate_xlabel, ha="right")
        else:  # "dot"
            # mean value and fraction of non-zero cells for every group of the current gene
            by_cluster = cur_gene_df.groupby("clusters_", observed=True)[gene]
            avg_perc_cluster = pd.DataFrame(
                {
                    "avg": by_cluster.mean(),
                    "perc": by_cluster.apply(lambda x: (x != 0).mean()),
                }
            )
            ordered = avg_perc_cluster.loc[list(clusters)]

            # scatter needs coordinates of equal length: one point per group, all on this gene
            gene_pos = [gene] * len(cluster_labels)
            ax.scatter(
                x=cluster_labels if transpose else gene_pos,
                y=gene_pos if transpose else cluster_labels,
                s=ordered["perc"].to_numpy() * point_size,
                c=ordered["avg"].to_numpy(),
                cmap="viridis" if cmap is None else cmap,
                rasterized=False,
                edgecolor=edgecolor,
                linewidth=linewidth,
                alpha=alpha,
            )

        # Only one panel keeps the group tick labels: the last row when transposed, otherwise the
        # first column.
        if transpose:
            if igene != last_panel:
                ax.set_xticks([])
            else:
                ax.set_xticklabels(cluster_labels, rotation=rotate_xlabel, ha="right")
            ax.set_xlabel("")
        else:
            if igene != 0:
                ax.set_yticks([])
            else:
                ax.set_yticklabels(cluster_labels, rotation=rotate_ylabel, ha="right", va="center")
            ax.set_ylabel("")

    if save_show_or_return == "save":
        s_kwargs = {
            "path": None,
            "prefix": "violin",
            "dpi": None,
            "ext": "pdf",
            "transparent": True,
            "close": True,
            "verbose": True,
        }
        s_kwargs = update_dict(s_kwargs, save_kwargs)

        save_fig(**s_kwargs)
        if background is not None:
            reset_rcParams()
    elif save_show_or_return == "show":
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plt.tight_layout()

        plt.show()
        if background is not None:
            reset_rcParams()
    elif save_show_or_return == "return":
        if background is not None:
            reset_rcParams()

        return fig, axes