import networkx as nx
import numpy as np
import nxviz as nv
import nxviz.annotate
import pandas as pd
from matplotlib.axes import Axes
from ..tools.utils import flatten, index_gene, update_dict
from .utils import save_fig, set_colorbar
from .utils_graph import ArcPlot
def nxvizPlot(
adata,
cluster,
cluster_name,
edges_list,
plot="arcplot",
network=None,
weight_scale=5e3,
weight_threshold=1e-4,
figsize=(6, 6),
save_show_or_return="show",
save_kwargs={},
**kwargs,
):
"""Arc or circos plot of gene regulatory network for a particular cell cluster.
Parameters
----------
adata: :class:`~anndata.AnnData`.
AnnData object.
cluster: `str`
The group key that points to the columns of `adata.obs`.
cluster_name: `str` (default: `None`)
The group whose network and arcplot will be constructed and created.
edges_list: `dict` of `pandas.DataFrame`
A dictionary of dataframe of interactions between input genes for each group of cells based on ranking
information of Jacobian analysis. Each composite dataframe has `regulator`, `target` and `weight` three
columns.
plot: `str` (default: `arcplot`)
Which nxviz plot to use, one of {'arcplot', 'circosplot'}.
network: class:`~networkx.classes.digraph.DiGraph`
A direct network for this cluster constructed based on Jacobian analysis.
weight_scale: `float` (default: `1e3`)
Because values in Jacobian matrix is often small, the value will be multiplied by the weight_scale so that
the edge will have proper width in display.
weight_threshold: `float` (default: `weight_threshold`)
The threshold of weight that will be used to trim the edges for network reconstruction.
figsize: `None` or `[float, float]` (default: (6, 6)
The width and height of each panel in the figure.
save_show_or_return: `str` {'save', 'show', 'return'} (default: `show`)
Whether to save, show or return the figure.
save_kwargs: `dict` (default: `{}`)
A dictionary that will passed to the save_fig function. By default it is an empty dictionary and the
save_fig function will use the {"path": None, "prefix": 'arcplot', "dpi": None, "ext": 'pdf', "transparent":
True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a dictionary that
properly modify those keys according to your needs.
**kwargs:
Additional parameters that will pass to ArcPlot or CircosPlot
Returns
-------
Nothing but plot an ArcPlot of the input direct network.
"""
_, has_labeling = (
adata.uns["pp"].get("has_splicing"),
adata.uns["pp"].get("has_labeling"),
)
layer = "M_s" if not has_labeling else "M_t"
if "layer" in kwargs.keys():
layer = kwargs.pop("layer")
import matplotlib.pyplot as plt
try:
import networkx as nx
import nxviz as nv
except ImportError:
raise ImportError(
"You need to install the packages `networkx, nxviz`."
"install networkx via `pip install networkx`."
"install nxviz via `pip install nxviz`."
)
if edges_list is not None:
network = nx.from_pandas_edgelist(
edges_list[cluster_name].query("weight > @weight_threshold"),
"regulator",
"target",
edge_attr="weight",
create_using=nx.DiGraph(),
)
if len(network.node) == 0:
raise ValueError(
f"weight_threshold is too high, no edge has weight than {weight_threshold} " f"for cluster {cluster}."
)
# Iterate over all the nodes in G, including the metadata
if type(cluster_name) is str:
cluster_names = [cluster_name]
for n, d in network.nodes(data=True):
# Calculate the degree of each node: G.node[n]['degree']
network.nodes[n]["degree"] = nx.degree(network, n)
# data has to be float
if cluster is not None:
network.nodes[n]["size"] = (
adata[adata.obs[cluster].isin(cluster_names), n].layers[layer].A.mean().astype(float)
)
else:
network.nodes[n]["size"] = adata[:, n].layers[layer].A.mean().astype(float)
network.nodes[n]["label"] = n
for e in network.edges():
network.edges[e]["weight"] *= weight_scale
if plot.lower() == "arcplot":
prefix = "arcPlot"
# Create the customized ArcPlot object: a2
nv_ax = nv.ArcPlot(
network,
node_order=kwargs.pop("node_order", "degree"),
node_size=kwargs.pop("node_size", None),
node_grouping=kwargs.pop("node_grouping", None),
group_order=kwargs.pop("group_order", "alphabetically"),
node_color=kwargs.pop("node_color", "size"),
node_labels=kwargs.pop("node_labels", True),
edge_width=kwargs.pop("edge_width", "weight"),
edge_color=kwargs.pop("edge_color", None),
data_types=kwargs.pop("data_types", None),
nodeprops=kwargs.pop(
"nodeprops",
{
"facecolor": "None",
"alpha": 0.2,
"cmap": "viridis",
"label": "label",
},
),
edgeprops=kwargs.pop("edgeprops", {"facecolor": "None", "alpha": 0.2}),
node_label_color=kwargs.pop("node_label_color", False),
group_label_position=kwargs.pop("group_label_position", None),
group_label_color=kwargs.pop("group_label_color", False),
fontsize=kwargs.pop("fontsize", 10),
fontfamily=kwargs.pop("fontfamily", "serif"),
figsize=figsize,
)
elif plot.lower() == "circosplot":
prefix = "circosPlot"
# Create the customized CircosPlot object: a2
nv_ax = nv.CircosPlot(
network,
node_order=kwargs.pop("node_order", "degree"),
node_size=kwargs.pop("node_size", None),
node_grouping=kwargs.pop("node_grouping", None),
group_order=kwargs.pop("group_order", "alphabetically"),
node_color=kwargs.pop("node_color", "size"),
node_labels=kwargs.pop("node_labels", True),
edge_width=kwargs.pop("edge_width", "weight"),
edge_color=kwargs.pop("edge_color", None),
data_types=kwargs.pop("data_types", None),
nodeprops=kwargs.pop("nodeprops", None),
node_label_layout="rotation",
edgeprops=kwargs.pop("edgeprops", {"facecolor": "None", "alpha": 0.2}),
node_label_color=kwargs.pop("node_label_color", False),
group_label_position=kwargs.pop("group_label_position", None),
group_label_color=kwargs.pop("group_label_color", False),
fontsize=kwargs.pop("fontsize", 10),
fontfamily=kwargs.pop("fontfamily", "serif"),
figsize=figsize,
)
# recover network edge weights
for e in network.edges():
network.edges[e]["weight"] /= weight_scale
if save_show_or_return == "save":
# Draw a to the screen
nv_ax.draw()
plt.autoscale()
s_kwargs = {
"path": None,
"prefix": prefix,
"dpi": None,
"ext": "pdf",
"transparent": True,
"close": True,
"verbose": True,
}
s_kwargs = update_dict(s_kwargs, save_kwargs)
save_fig(**s_kwargs)
elif save_show_or_return == "show":
# Draw a to the screen
nv_ax.draw()
plt.autoscale()
# Display the plot
plt.show()
# plt.savefig('./unknown_arcplot.pdf', dpi=300)
elif save_show_or_return == "return":
return nv_ax
[docs]def arcPlot(
adata,
cluster,
cluster_name,
edges_list=None,
network=None,
color=None,
cmap="viridis",
node_size=100,
cbar=True,
cbar_title=None,
figsize=(6, 6),
save_show_or_return="show",
save_kwargs={},
**kwargs,
):
"""Arc plot of gene regulatory network for a particular cell cluster.
Parameters
----------
adata: :class:`~anndata.AnnData`.
AnnData object.
cluster: `str`
The group key that points to the columns of `adata.obs`.
cluster_name: `str` (default: `None`)
The group whose network and arcplot will be constructed and created.
edges_list: `dict` of `pandas.DataFrame`
A dictionary of dataframe of interactions between input genes for each group of cells based on ranking
information of Jacobian analysis. Each composite dataframe has `regulator`, `target` and `weight` three
columns.
network: class:`~networkx.classes.digraph.DiGraph`
A direct network for this cluster constructed based on Jacobian analysis.
color: `str` or None (default: `None`)
The layer key that will be used to retrieve average expression to color the node of each gene.
node_size: `float` (default: `100`)
The size of the node, a constant.
cbar: `bool` (default: `True`)
Whether or not to display colorbar when `color` is not None.
cbar_title: `float` (default: `weight_threshold`)
The title of the color bar when displayed.
figsize: `None` or `[float, float]` (default: (6, 6)
The width and height of each panel in the figure.
save_show_or_return: `str` {'save', 'show', 'return'} (default: `show`)
Whether to save, show or return the figure.
save_kwargs: `dict` (default: `{}`)
A dictionary that will passed to the save_fig function. By default it is an empty dictionary and the
save_fig function will use the {"path": None, "prefix": 'arcplot', "dpi": None, "ext": 'pdf', "transparent":
True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a dictionary that
properly modify those keys according to your needs.
**kwargs:
Additional parameters that will eventually pass to ArcPlot.
Returns
-------
Nothing but plot an ArcPlot of the input direct network.
"""
"""nxvizPlot(adata,
cluster,
cluster_name,
edges_list,
plot='arcplot',
network=network,
weight_scale=weight_scale,
figsize=figsize,
save_show_or_return=save_show_or_return,
save_kwargs=save_kwargs,
**kwargs,
)"""
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
try:
import networkx as nx
except ImportError:
raise ImportError("You need to install the package `networkx`." "install networkx via `pip install networkx`.")
if edges_list is not None:
network = nx.from_pandas_edgelist(
edges_list[cluster],
"regulator",
"target",
edge_attr="weight",
create_using=nx.DiGraph(),
)
# Iterate over all the nodes in G, including the metadata
if type(cluster_name) is str:
cluster_names = [cluster_name]
if type(color) is str and color in adata.layers.keys():
data = adata[adata.obs[cluster].isin(cluster_names), :].layers[color]
color = []
for gene in network.nodes:
c = np.mean(flatten(index_gene(adata, data, [gene])))
color.append(c)
else:
color = None
fig, ax = plt.subplots(figsize=figsize)
ap = ArcPlot(network=network, c=color, s=node_size, cmap=cmap, **kwargs)
node_degree = [network.degree[i] for i in network.nodes]
ap.draw(node_order=node_degree)
if cbar and color is not None:
norm = matplotlib.colors.Normalize(vmin=np.min(color), vmax=np.max(color))
mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
mappable.set_array(color)
cb = plt.colorbar(
mappable,
cax=set_colorbar(
ax,
{
"width": "12%", # width = 5% of parent_bbox width
"height": "100%", # height : 50%
"loc": "upper right",
"bbox_to_anchor": (0.85, 0.85, 0.145, 0.17),
"borderpad": 1.85,
},
),
ax=ax,
)
if cbar_title is not None:
cb.ax.set_title(cbar_title)
cb.set_alpha(1)
cb.draw_all()
cb.locator = MaxNLocator(nbins=3, integer=True)
cb.update_ticks()
if save_show_or_return == "save":
# Draw a to the screen
plt.autoscale()
s_kwargs = {
"path": None,
"prefix": "arcPlot",
"dpi": None,
"ext": "pdf",
"transparent": True,
"close": True,
"verbose": True,
}
s_kwargs = update_dict(s_kwargs, save_kwargs)
save_fig(**s_kwargs)
elif save_show_or_return == "show":
# Draw a to the screen
plt.autoscale()
# Display the plot
plt.show()
# plt.savefig('./unknown_arcplot.pdf', dpi=300)
elif save_show_or_return == "return":
return ap
[docs]def circosPlot(
network: nx.Graph,
node_label_key: str = None,
circos_label_layout: str = "rotate",
node_color_key: str = None,
show_colorbar=True,
edge_lw_scale: float = 0.5,
edge_alpha_scale: float = 0.5,
) -> Axes:
"""wrapper for drawing circos plot via nxviz >= 0.7.3
Parameters
----------
network : nx.Graph
a network graph instance
node_label_key : str, optional
node label (attribute) in network for grouping nodes, by default None
circos_label_layout : str, optional
layout of circos plot (see nxviz docs for details), by default "rotate"
node_color_key : str, optional
node attribute in network, corresponding to color values of nodes, by default None
show_colorbar : bool, optional
whether to show colorbar, by default True
edge_lw_scale : float
the line width scale of edges drawn in in plot
edge_alpha_scale : float
the alpha (opacity, transparency) scale of edges, the value shoud be in [0, 1.0]
"""
ax = nv.circos(
network,
group_by=node_label_key,
node_color_by=node_color_key,
edge_lw_by="weight",
edge_alpha_by="weight",
edge_enc_kwargs={
"lw_scale": edge_lw_scale,
"alpha_scale": edge_alpha_scale,
},
)
nv.annotate.circos_labels(network, group_by=node_label_key, layout=circos_label_layout)
if node_color_key and show_colorbar:
nv.annotate.node_colormapping(
network,
color_by=node_color_key,
legend_kwargs={"loc": "upper right", "bbox_to_anchor": (0.0, 1.0)},
ax=None,
)
return ax
[docs]def circosPlotDeprecated(
adata,
cluster,
cluster_name,
edges_list,
network=None,
weight_scale=5e3,
weight_threshold=1e-4,
figsize=(12, 6),
save_show_or_return="show",
save_kwargs={},
**kwargs,
):
"""Note: this function is written with nxviz old version (<=0.3.x, or higher) API
for the latest nxviz version compatibility, please refer to `dyn.pl.circos_plot`.
Circos plot of gene regulatory network for a particular cell cluster.
Parameters
----------
adata: :class:`~anndata.AnnData`.
AnnData object.
cluster: `str`
The group key that points to the columns of `adata.obs`.
cluster_name: `str` (default: `None`)
The group whose network and arcplot will be constructed and created.
edges_list: `dict` of `pandas.DataFrame`
A dictionary of dataframe of interactions between input genes for each group of cells based on ranking
information of Jacobian analysis. Each composite dataframe has `regulator`, `target` and `weight` three
columns.
network: class:`~networkx.classes.digraph.DiGraph`
A direct network for this cluster constructed based on Jacobian analysis.
weight_scale: `float` (default: `1e3`)
Because values in Jacobian matrix is often small, the value will be multiplied by the weight_scale so that
the edge will have proper width in display.
weight_threshold: `float` (default: `weight_threshold`)
The threshold of weight that will be used to trim the edges for network reconstruction.
figsize: `None` or `[float, float]` (default: (12, 6)
The width and height of each panel in the figure.
save_show_or_return: `str` {'save', 'show', 'return'} (default: `show`)
Whether to save, show or return the figure.
save_kwargs: `dict` (default: `{}`)
A dictionary that will passed to the save_fig function. By default it is an empty dictionary and the
save_fig function will use the {"path": None, "prefix": 'arcplot', "dpi": None, "ext": 'pdf', "transparent":
True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a dictionary that
properly modify those keys according to your needs.
**kwargs:
Additional parameters that will eventually pass to CircosPlot.
Returns
-------
Nothing but plot an CircosPlot of the input direct network.
"""
nxvizPlot(
adata,
cluster,
cluster_name,
edges_list,
plot="circosplot",
network=network,
weight_scale=weight_scale,
weight_threshold=weight_threshold,
figsize=figsize,
save_show_or_return=save_show_or_return,
save_kwargs=save_kwargs,
**kwargs,
)
[docs]def hivePlot(
adata,
edges_list,
cluster,
cluster_names=None,
weight_threshold=1e-4,
figsize=(6, 6),
save_show_or_return="show",
save_kwargs={},
):
"""Hive plot of cell cluster specific gene regulatory networks.
Parameters
----------
adata: :class:`~anndata.AnnData`.
AnnData object.
edges_list: `dict` of `pandas.DataFrame`
A dictionary of dataframe of interactions between input genes for each group of cells based on ranking
information of Jacobian analysis. Each composite dataframe has `regulator`, `target` and `weight` three
columns.
cluster: `str`
The group key that points to the columns of `adata.obs`.
cluster_names: `str` (default: `None`)
The group whose network and arcplot will be constructed and created.
weight_threshold: `float` (default: `weight_threshold`)
The threshold of weight that will be used to trim the edges for network reconstruction.
figsize: `None` or `[float, float]` (default: (6, 6)
The width and height of each panel in the figure.
save_show_or_return: `str` {'save', 'show', 'return'} (default: `show`)
Whether to save, show or return the figure.
save_kwargs: `dict` (default: `{}`)
A dictionary that will passed to the save_fig function. By default it is an empty dictionary and the
save_fig function will use the {"path": None, "prefix": 'hiveplot', "dpi": None, "ext": 'pdf',
"transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a
dictionary that properly modify those keys according to your needs.
Returns
-------
Nothing but plot a hive plot of the input cell cluster specific direct network.
"""
# _, has_labeling = (
# adata.uns["pp"].get("has_splicing"),
# adata.uns["pp"].get("has_labeling"),
# )
# layer = "M_s" if not has_labeling else "M_t"
# from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
try:
import networkx as nx
from hiveplotlib import Axis, HivePlot, Node
from hiveplotlib.viz import axes_viz_mpl, edge_viz_mpl, node_viz_mpl
except ImportError:
raise ImportError(
"You need to install the package `networkx, hiveplotlib`."
"install hiveplotlib via `pip install hiveplotlib`"
"install networkx via `pip install nxviz`."
)
reg_groups = list(adata.obs[cluster].unique())
if not set(edges_list.keys()).issubset(reg_groups):
raise ValueError(
f"the edges_list's keys are not equal or subset of the clusters from the " f"adata.obs[{cluster}]"
)
if cluster_names is not None:
reg_groups = list(set(reg_groups).intersection(cluster_names))
if len(reg_groups) == 0:
raise ValueError(
f"the clusters argument {cluster_names} provided doesn't match up with any clusters from the " f"adata."
)
combined_edges, G, edges_dict = None, {}, {}
for i, grp in enumerate(edges_list.keys()):
G[grp] = nx.from_pandas_edgelist(
edges_list[grp].query("weight > @weight_threshold"),
"regulator",
"target",
edge_attr="weight",
create_using=nx.DiGraph(),
)
if len(G[grp].node) == 0:
raise ValueError(
f"weight_threshold is too high, no edge has weight than {weight_threshold} " f"for cluster {grp}."
)
edges_dict[grp] = np.array(G[grp].edges)
combined_edges = edges_list[grp] if combined_edges is None else pd.concat((combined_edges, edges_list[grp]))
# pull out degree information from nodes for later use
combined_G = nx.from_pandas_edgelist(
combined_edges.query("weight > @weight_threshold"),
"regulator",
"target",
edge_attr="weight",
create_using=nx.DiGraph(),
)
edges = np.array(combined_G.edges)
node_ids, degrees = np.unique(edges, return_counts=True)
nodes = []
for node_id, degree in zip(node_ids, degrees):
# store the index number as a way to align the nodes on axes
combined_G.nodes.data()[node_id]["loc"] = node_id
# also store the degree of each node as another way to
# align nodes on axes
combined_G.nodes.data()[node_id]["degree"] = degree
temp_node = Node(unique_id=node_id, data=combined_G.nodes.data()[node_id])
nodes.append(temp_node)
hp = HivePlot()
# nodes ###
hp.add_nodes(nodes)
# axes ###
angles = np.linspace(0, 360, len(reg_groups) + 1)
axes = []
for i, grp in enumerate(reg_groups):
axis = Axis(axis_id=grp, start=1, end=5, angle=angles[i], long_name=grp)
axes.append(axis)
hp.add_axes(axes)
# node assignments ###
nodes = [node.unique_id for node in nodes]
# assign nodes and sorting procedure to position nodes on axis
for i, grp in enumerate(reg_groups):
hp.place_nodes_on_axis(axis_id=grp, unique_ids=nodes, sorting_feature_to_use="degree")
for i, grp in enumerate(reg_groups):
# edges ###
nex_grp = reg_groups[i + 1] if i < len(reg_groups) - 1 else reg_groups[0]
hp.connect_axes(
edges=edges_dict[grp],
axis_id_1=grp,
axis_id_2=nex_grp,
c="C" + str(i),
) # different color for each lineage
# plot axes
fig, ax = axes_viz_mpl(hp, figsize=figsize, axes_labels_buffer=1.4)
# plot nodes
node_viz_mpl(hp, fig=fig, ax=ax, s=80, c="black")
# plot edges
edge_viz_mpl(hive_plot=hp, fig=fig, ax=ax, alpha=0.7, zorder=-1)
# ax.set_title("Hive Plot", fontsize=20, y=0.9)
# custom_lines = [Line2D([0], [0], color=f'C{i}', lw=3, linestyle='-') for i in range(len(reg_groups))]
# ax.legend(custom_lines, reg_groups, loc='upper left', bbox_to_anchor=(0.37, 0.35),
# title="Regulatory network based on Jacobian analysis")
if save_show_or_return == "save":
s_kwargs = {
"path": None,
"prefix": "hiveplot",
"dpi": None,
"ext": "pdf",
"transparent": True,
"close": True,
"verbose": True,
}
s_kwargs = update_dict(s_kwargs, save_kwargs)
save_fig(**s_kwargs)
elif save_show_or_return == "show":
plt.tight_layout()
plt.show()
elif save_show_or_return == "return":
return ax