Source code for dynamo.plot.networks

from typing import Any, Dict, List, Optional, Tuple, Union

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import networkx as nx
import numpy as np
import pandas as pd
from anndata import AnnData
from matplotlib.axes import Axes

from ..tools.utils import flatten, index_gene, update_dict
from .utils import save_fig, save_show_ret, set_colorbar
from .utils_graph import ArcPlot


def _dense_mean(matrix: Any) -> float:
    """Return the mean over all entries of a dense or scipy-sparse matrix as a Python float."""
    # Sparse containers expose `toarray`; dense arrays / np.matrix go through `np.asarray`.
    dense = matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)
    return float(dense.mean())


def nxvizPlot(
    adata: AnnData,
    cluster: str,
    cluster_name: str,
    edges_list: Dict[str, pd.DataFrame],
    plot: Literal["arcplot", "circosplot"] = "arcplot",
    network: Optional[nx.classes.digraph.DiGraph] = None,
    weight_scale: float = 5e3,
    weight_threshold: float = 1e-4,
    figsize: Tuple[float, float] = (6, 6),
    save_show_or_return: Literal["save", "show", "return"] = "show",
    save_kwargs: Dict[str, Any] = {},
    **kwargs,
) -> Optional[Any]:
    """Arc or circos plot of gene regulatory network for a particular cell cluster.

    Args:
        adata: an AnnData object. `adata.uns["pp"]["has_labeling"]` selects the default expression layer
            ("M_t" when truthy, otherwise "M_s").
        cluster: the group key that points to the column of `adata.obs`. If None, node sizes are averaged over
            all cells.
        cluster_name: the group whose network and plot would be constructed and created. It is used as the key
            into `edges_list`; a list of group names is also accepted for selecting cells when `network` is
            given directly.
        edges_list: a dictionary of dataframe of interactions between input genes for each group of cells based on
            ranking information of Jacobian analysis. Each composite dataframe has `regulator`, `target` and `weight`
            three columns. If None, `network` must be provided.
        plot: which nxviz plot to use, one of 'arcplot' or 'circosplot' (case-insensitive). Defaults to "arcplot".
        network: a direct network for this cluster constructed based on Jacobian analysis. Ignored (rebuilt) when
            `edges_list` is not None. Defaults to None.
        weight_scale: the factor by which the edge weights are multiplied while drawing so that the edge could be
            displayed with proper width. The original weights are restored afterwards. Defaults to 5e3.
        weight_threshold: the threshold of weight that will be used to trim the edges for network reconstruction.
            Defaults to 1e-4.
        figsize: the size of each panel of the figure. Currently not used by this function. Defaults to (6, 6).
        save_show_or_return: whether to save, show, or return the plotted figure. Defaults to "show".
        save_kwargs: a dictionary that will be passed to `save_show_ret`. It is never mutated here.
            Defaults to {}.
        **kwargs: recognized keys are `layer` (expression layer used for node size), `node_grouping`, `sort_by`,
            `node_color_by`, `node_alpha_by`, `node_size_by`, `edge_color_by`, `edge_lw_by`, `edge_alpha_by` and
            `layout_kwargs` (a dict expanded into the nxviz call). Other keys are ignored.

    Raises:
        ImportError: nxviz or networkx is not installed.
        ValueError: `plot` is not supported, neither `edges_list` nor `network` is given, or `weight_threshold`
            is so high that no edge passes it.

    Returns:
        Whatever `save_show_ret` returns for the object produced by `nxviz.arc` / `nxviz.circos`.
    """

    # Validate the plot type first so that nothing is mutated before failing.
    plot_key = plot.lower()
    if plot_key not in ("arcplot", "circosplot"):
        raise ValueError(f"`plot` must be one of 'arcplot' or 'circosplot', got {plot!r}.")

    has_labeling = adata.uns["pp"].get("has_labeling")
    layer = kwargs.pop("layer", "M_t" if has_labeling else "M_s")

    try:
        import networkx as nx
        import nxviz as nv
    except ImportError:
        raise ImportError(
            "You need to install the packages `networkx, nxviz`."
            "install networkx via `pip install networkx`."
            "install nxviz via `pip install nxviz`."
        )

    if edges_list is not None:
        # NOTE: `@weight_threshold` is resolved by pandas from this function's local scope,
        # so the query must stay directly inside this function body.
        network = nx.from_pandas_edgelist(
            edges_list[cluster_name].query("weight > @weight_threshold"),
            "regulator",
            "target",
            edge_attr="weight",
            create_using=nx.DiGraph(),
        )
        if len(network.nodes) == 0:
            raise ValueError(
                f"weight_threshold is too high, no edge has weight than {weight_threshold} "
                f"for cluster {cluster_name}."
            )
    elif network is None:
        raise ValueError("Either `edges_list` or `network` must be provided.")

    # Select the cells once instead of recomputing the mask for every node.
    if cluster is None:
        cell_mask = None
    else:
        cluster_names = [cluster_name] if isinstance(cluster_name, str) else list(cluster_name)
        cell_mask = adata.obs[cluster].isin(cluster_names)

    # Attach the metadata nxviz uses for ordering / coloring / labeling to each node.
    for gene in network.nodes:
        attrs = network.nodes[gene]
        attrs["degree"] = nx.degree(network, gene)
        expression = adata[:, gene] if cell_mask is None else adata[cell_mask, gene]
        # data has to be float
        attrs["size"] = _dense_mean(expression.layers[layer])
        attrs["label"] = gene

    if plot_key == "arcplot":
        prefix, draw = "arcPlot", nv.arc
    else:
        prefix, draw = "circosPlot", nv.circos

    # Temporarily scale the edge weights for drawing. The originals are remembered so they
    # can be restored exactly, even if nxviz raises.
    original_weights = {edge: network.edges[edge]["weight"] for edge in network.edges()}
    for edge, weight in original_weights.items():
        network.edges[edge]["weight"] = weight * weight_scale
    try:
        nv_ax = draw(
            network,
            group_by=kwargs.pop("node_grouping", None),
            sort_by=kwargs.pop("sort_by", "degree"),
            node_color_by=kwargs.pop("node_color_by", "size"),
            node_alpha_by=kwargs.pop("node_alpha_by", None),
            node_size_by=kwargs.pop("node_size_by", None),
            edge_color_by=kwargs.pop("edge_color_by", None),
            edge_lw_by=kwargs.pop("edge_lw_by", "weight"),
            edge_alpha_by=kwargs.pop("edge_alpha_by", None),
            **kwargs.pop("layout_kwargs", {}),
        )
    finally:
        # recover network edge weights
        for edge, weight in original_weights.items():
            network.edges[edge]["weight"] = weight

    return save_show_ret(prefix, save_show_or_return, save_kwargs, nv_ax)


def arcPlot(
    adata: AnnData,
    cluster: str,
    cluster_name: str,
    edges_list: Optional[Dict[str, pd.DataFrame]] = None,
    network: Optional[nx.classes.DiGraph] = None,
    color: Optional[str] = None,
    cmap: str = "viridis",
    node_size: int = 100,
    cbar: bool = True,
    cbar_title: Optional[str] = None,
    figsize: Tuple[float, float] = (6, 6),
    save_show_or_return: Literal["save", "show", "return"] = "show",
    save_kwargs: Dict[str, Any] = {},
    **kwargs,
) -> Optional[Any]:
    """Arc plot of gene regulatory network for a particular cell cluster.

    Args:
        adata: an AnnData object.
        cluster: the group key that points to the column of `adata.obs`.
        cluster_name: the group whose network and arcplot will be constructed and created. It is used as the
            key into `edges_list`.
        edges_list: a dictionary of dataframe of interactions between input genes for each group of cells based on
            ranking information of Jacobian analysis. Each composite dataframe has `regulator`, `target` and `weight`
            three columns. If None, the network should be provided directly. Defaults to None.
        network: a direct network for this cluster constructed based on Jacobian analysis. If None, `edges_list`
            must be provided to construct the network. Defaults to None.
        color: the layer key that will be used to retrieve average expression to color the node of each gene. If
            None, or not a key of `adata.layers`, the nodes would not be colored. Defaults to None.
        cmap: the color map used for the ArcPlot. Defaults to "viridis".
        node_size: the size of the node, a constant. Defaults to 100.
        cbar: whether to display the colorbar when `color` is not None. Defaults to True.
        cbar_title: the title of the colorbar when displayed. Defaults to None.
        figsize: the size of each panel of the figure. Defaults to (6, 6).
        save_show_or_return: whether to save, show, or return the plotted ArcPlot. Could be one of 'save', 'show',
            or 'return' ('both' and 'all' are also handled). Defaults to "show".
        save_kwargs: a dictionary that will be passed to the save_fig function. By default, it is an empty
            dictionary and the save_fig function will use the {"path": None, "prefix": 'arcPlot', "dpi": None,
            "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise, you
            can provide a dictionary that properly modify those keys according to your needs. Defaults to {}.
        **kwargs: any other kwargs that would be passed to ArcPlot.

    Raises:
        ImportError: `networkx` not installed.
        KeyError: `edges_list` is given but has no entry for `cluster_name`.
        ValueError: neither `edges_list` nor `network` is provided.

    Returns:
        None would be returned by default. If `save_show_or_return` is set to 'return' (or 'all'), the generated
        ArcPlot object would be returned.
    """

    import matplotlib
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    try:
        import networkx as nx
    except ImportError:
        raise ImportError("You need to install the package `networkx`." "install networkx via `pip install networkx`.")

    if edges_list is not None:
        # `edges_list` is keyed by group name. Older versions of this function indexed it with the
        # obs column key (`cluster`), so that key is still accepted as a fallback.
        if isinstance(cluster_name, str) and cluster_name in edges_list:
            edges_key = cluster_name
        elif cluster in edges_list:
            edges_key = cluster
        else:
            raise KeyError(f"`edges_list` has no entry for cluster {cluster_name!r}.")
        network = nx.from_pandas_edgelist(
            edges_list[edges_key],
            "regulator",
            "target",
            edge_attr="weight",
            create_using=nx.DiGraph(),
        )
    elif network is None:
        raise ValueError("Either `edges_list` or `network` must be provided.")

    cluster_names = [cluster_name] if isinstance(cluster_name, str) else list(cluster_name)

    # Average expression per gene (node) within the selected cells, used as node color.
    if isinstance(color, str) and color in adata.layers.keys():
        data = adata[adata.obs[cluster].isin(cluster_names), :].layers[color]
        node_colors = [np.mean(flatten(index_gene(adata, data, [gene]))) for gene in network.nodes]
    else:
        node_colors = None

    fig, ax = plt.subplots(figsize=figsize)
    ap = ArcPlot(network=network, c=node_colors, s=node_size, cmap=cmap, **kwargs)

    node_degree = [network.degree[i] for i in network.nodes]
    ap.draw(node_order=node_degree)

    # Skip the colorbar for an empty color list: min/max of an empty sequence would raise.
    if cbar and node_colors is not None and len(node_colors) > 0:
        norm = matplotlib.colors.Normalize(vmin=np.min(node_colors), vmax=np.max(node_colors))

        mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
        mappable.set_array(node_colors)
        cb = plt.colorbar(
            mappable,
            cax=set_colorbar(
                ax,
                {
                    "width": "12%",
                    "height": "100%",
                    "loc": "upper right",
                    "bbox_to_anchor": (0.85, 0.85, 0.145, 0.17),
                    "borderpad": 1.85,
                },
            ),
            ax=ax,
        )
        if cbar_title is not None:
            cb.ax.set_title(cbar_title)

        cb.set_alpha(1)
        cb._draw_all()
        cb.locator = MaxNLocator(nbins=3, integer=True)
        cb.update_ticks()

    if save_show_or_return in ["save", "both", "all"]:
        plt.autoscale()
        s_kwargs = {
            "path": None,
            "prefix": "arcPlot",
            "dpi": None,
            "ext": "pdf",
            "transparent": True,
            "close": True,
            "verbose": True,
        }
        s_kwargs = update_dict(s_kwargs, save_kwargs)

        # Keep the figure open when it still has to be shown / returned.
        if save_show_or_return in ["both", "all"]:
            s_kwargs["close"] = False

        save_fig(**s_kwargs)
    if save_show_or_return in ["show", "both", "all"]:
        plt.autoscale()
        plt.show()
    if save_show_or_return in ["return", "all"]:
        return ap
    return None
def circosPlot(
    network: nx.Graph,
    node_label_key: Optional[str] = None,
    circos_label_layout: str = "rotate",
    node_color_key: Optional[str] = None,
    show_colorbar: bool = True,
    edge_lw_scale: float = 0.5,
    edge_alpha_scale: float = 0.5,
) -> Axes:
    """Draw a circos plot of `network` through the functional nxviz API.

    Args:
        network: the graph to draw. Edge line width and edge alpha are both encoded from the `weight` edge
            attribute.
        node_label_key: node attribute used to group nodes and to group the circos labels. Defaults to None.
        circos_label_layout: layout passed to `nxviz.annotate.circos_labels`. Defaults to "rotate".
        node_color_key: node attribute used to color the nodes. Defaults to None.
        show_colorbar: whether to add the node color mapping legend; only has an effect when `node_color_key`
            is set. Defaults to True.
        edge_lw_scale: the line width scale of the edges. Defaults to 0.5.
        edge_alpha_scale: the alpha scale of the edges. Defaults to 0.5.

    Raises:
        ImportError: `nxviz` is not installed.

    Returns:
        The object returned by `nxviz.circos`.
    """

    try:
        import nxviz as nv
        from nxviz import annotate
    except ImportError:
        raise ImportError("install nxviz via `pip install nxviz`.")

    edge_encoding = {
        "lw_scale": edge_lw_scale,
        "alpha_scale": edge_alpha_scale,
    }
    circos_ax = nv.circos(
        network,
        group_by=node_label_key,
        node_color_by=node_color_key,
        edge_lw_by="weight",
        edge_alpha_by="weight",
        edge_enc_kwargs=edge_encoding,
    )

    annotate.circos_labels(network, group_by=node_label_key, layout=circos_label_layout)

    # Nothing more to draw unless a color key is set and the colorbar was requested.
    if not (node_color_key and show_colorbar):
        return circos_ax

    legend_options = {"loc": "upper right", "bbox_to_anchor": (0.0, 1.0)}
    annotate.node_colormapping(
        network,
        color_by=node_color_key,
        legend_kwargs=legend_options,
        ax=None,
    )
    return circos_ax
def circosPlotDeprecated(
    adata: AnnData,
    cluster: str,
    cluster_name: str,
    edges_list: Dict[str, pd.DataFrame],
    network: nx.classes.digraph.DiGraph = None,
    weight_scale: float = 5e3,
    weight_threshold: float = 1e-4,
    figsize: Tuple[float, float] = (12, 6),
    save_show_or_return: Literal["save", "show", "return"] = "show",
    save_kwargs: Dict[str, Any] = {},
    **kwargs,
) -> Optional[Any]:
    """Deprecated. A wrapper of `dynamo.pl.networks.nxvizPlot` to plot Circos graph.

    All arguments are forwarded unchanged to `nxvizPlot` with `plot="circosplot"`. See `nxvizPlot` for
    more information.

    Returns:
        The value returned by `nxvizPlot`: None by default; if `save_show_or_return` is set to be 'return', the
        generated `nxviz` plot object would be returned.
    """

    # The result must be propagated, otherwise `save_show_or_return="return"` always yields None.
    return nxvizPlot(
        adata,
        cluster,
        cluster_name,
        edges_list,
        plot="circosplot",
        network=network,
        weight_scale=weight_scale,
        weight_threshold=weight_threshold,
        figsize=figsize,
        save_show_or_return=save_show_or_return,
        save_kwargs=save_kwargs,
        **kwargs,
    )
def hivePlot(
    adata: AnnData,
    edges_list: Dict[str, pd.DataFrame],
    cluster: str,
    cluster_names: Optional[List[str]] = None,
    weight_threshold: float = 1e-4,
    figsize: Tuple[float, float] = (6, 6),
    save_show_or_return: Literal["save", "show", "return"] = "show",
    save_kwargs: Dict[str, Any] = {},
) -> Optional[Axes]:
    """Hive plot of cell cluster specific gene regulatory networks.

    One axis is drawn per selected cluster, every gene of the combined network is placed on every axis
    (sorted by its degree in the combined network), and the edges of each cluster connect its axis to the
    axis of the next cluster.

    Args:
        adata: an AnnData object.
        edges_list: a dictionary of dataframe of interactions between input genes for each group of cells based on
            ranking information of Jacobian analysis. Each composite dataframe has `regulator`, `target` and `weight`
            three columns. Its keys must be a subset of the groups in `adata.obs[cluster]`.
        cluster: the group key that points to the columns of `adata.obs`.
        cluster_names: the groups for which an axis will be drawn. If None, all groups that have an entry in
            `edges_list` are used. Defaults to None.
        weight_threshold: the threshold of weight that will be used to trim the edges for network reconstruction.
            Defaults to 1e-4.
        figsize: the size of each panel of the figure. Defaults to (6, 6).
        save_show_or_return: whether to save, show, or return the figure. Could be one of "save", "show", or
            "return". Defaults to "show".
        save_kwargs: a dictionary that will be passed to `save_show_ret`. It is never mutated here.
            Defaults to {}.

    Raises:
        ImportError: `networkx` or `hiveplotlib` not installed.
        ValueError: the keys of `edges_list` are not a subset of the groups in `adata.obs[cluster]`.
        ValueError: `cluster_names` matches no group, or no selected group has an entry in `edges_list`.
        ValueError: `weight_threshold` too high to have any edge to pass it.

    Returns:
        Whatever `save_show_ret` returns for the matplotlib axes of the figure.
    """

    try:
        import networkx as nx
        from hiveplotlib import Axis, HivePlot, Node
        from hiveplotlib.viz import axes_viz, edge_viz, node_viz
    except ImportError:
        raise ImportError(
            "You need to install the package `networkx, hiveplotlib`."
            "install hiveplotlib via `pip install hiveplotlib`"
            "install networkx via `pip install networkx`."
        )

    reg_groups = list(adata.obs[cluster].unique())
    if not set(edges_list.keys()).issubset(reg_groups):
        raise ValueError(
            f"the edges_list's keys are not equal or subset of the clusters from the " f"adata.obs[{cluster}]"
        )
    if cluster_names is not None:
        # Filter while keeping the order of `adata.obs[cluster]`, so the axis order and the
        # per-axis edge colors are deterministic (a set intersection is not).
        requested = set(cluster_names)
        reg_groups = [grp for grp in reg_groups if grp in requested]
        if len(reg_groups) == 0:
            raise ValueError(
                f"the clusters argument {cluster_names} provided doesn't match up with any clusters from the "
                f"adata."
            )

    # Only groups with an edge table can be connected; drawing an axis for the others would
    # fail with a KeyError when looking up their edges below.
    reg_groups = [grp for grp in reg_groups if grp in edges_list]
    if len(reg_groups) == 0:
        raise ValueError("none of the selected clusters has an entry in `edges_list`.")

    edges_dict = {}
    for grp in edges_list.keys():
        # NOTE: `@weight_threshold` is resolved by pandas from this function's local scope,
        # so the query must stay directly inside this function body.
        group_graph = nx.from_pandas_edgelist(
            edges_list[grp].query("weight > @weight_threshold"),
            "regulator",
            "target",
            edge_attr="weight",
            create_using=nx.DiGraph(),
        )
        if len(group_graph.nodes) == 0:
            raise ValueError(
                f"weight_threshold is too high, no edge has weight than {weight_threshold} " f"for cluster {grp}."
            )
        edges_dict[grp] = np.array(group_graph.edges)

    # Concatenate once instead of growing the frame inside the loop.
    combined_edges = pd.concat([edges_list[grp] for grp in edges_list.keys()])

    # pull out degree information from nodes for later use
    combined_G = nx.from_pandas_edgelist(
        combined_edges.query("weight > @weight_threshold"),
        "regulator",
        "target",
        edge_attr="weight",
        create_using=nx.DiGraph(),
    )
    edges = np.array(combined_G.edges)
    # Each appearance of a gene in the edge array counts as one unit of degree.
    node_ids, degrees = np.unique(edges, return_counts=True)

    nodes = []
    for node_id, degree in zip(node_ids, degrees):
        node_attrs = combined_G.nodes[node_id]
        # store the index number as a way to align the nodes on axes
        node_attrs["loc"] = node_id
        # also store the degree of each node as another way to align nodes on axes
        node_attrs["degree"] = degree
        nodes.append(Node(unique_id=node_id, data=node_attrs))

    hp = HivePlot()
    hp.add_nodes(nodes)

    # One axis per group, evenly spread; the extra linspace point keeps 0 and 360 from coinciding.
    angles = np.linspace(0, 360, len(reg_groups) + 1)
    axes = [
        Axis(axis_id=grp, start=1, end=5, angle=angles[i], long_name=grp) for i, grp in enumerate(reg_groups)
    ]
    hp.add_axes(axes)

    # assign nodes and sorting procedure to position nodes on axis
    node_unique_ids = [node.unique_id for node in nodes]
    for grp in reg_groups:
        hp.place_nodes_on_axis(axis_id=grp, unique_ids=node_unique_ids, sorting_feature_to_use="degree")

    for i, grp in enumerate(reg_groups):
        # Connect each axis to the next one, wrapping around from the last to the first.
        nex_grp = reg_groups[i + 1] if i < len(reg_groups) - 1 else reg_groups[0]
        hp.connect_axes(
            edges=edges_dict[grp],
            axis_id_1=grp,
            axis_id_2=nex_grp,
            c="C" + str(i),
        )  # different color for each lineage

    # plot axes
    fig, ax = axes_viz(hp, figsize=figsize, axes_labels_buffer=1.4)

    # plot nodes
    node_viz(hp, fig=fig, ax=ax, s=80, c="black")

    # plot edges
    edge_viz(hive_plot=hp, fig=fig, ax=ax, alpha=0.7, zorder=-1)

    return save_show_ret("hiveplot", save_show_or_return, save_kwargs, ax)