# Source code for dynamo.prediction.fate

```
import itertools
import warnings
from multiprocessing.dummy import Pool as ThreadPool
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
from anndata import AnnData
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
from ..configuration import DKM
from ..dynamo_logger import (
LoggerManager,
main_info,
main_info_insert_adata,
main_warning,
)
from ..utils import pca_to_expr
from ..tools.connectivity import construct_mapper_umap, correct_hnsw_neighbors, k_nearest_neighbors
from ..tools.utils import fetch_states, getTseq
from ..vectorfield import vector_field_function
from ..vectorfield.utils import vecfld_from_adata, vector_transformation
from .utils import integrate_vf_ivp
def fate(
    adata: AnnData,
    init_cells: list,
    init_states: Optional[np.ndarray] = None,
    basis: Optional[str] = None,
    layer: str = "X",
    dims: Optional[Union[int, List[int], Tuple[int], np.ndarray]] = None,
    genes: Optional[List] = None,
    t_end: Optional[float] = None,
    direction: str = "both",
    interpolation_num: int = 250,
    average: Union[bool, str] = False,
    sampling: str = "arc_length",
    VecFld_true: Optional[Callable] = None,
    inverse_transform: bool = False,
    Qkey: str = "PCs",
    scale: float = 1,
    cores: int = 1,
    **kwargs: dict,
) -> AnnData:
    """Predict the historical and future cell transcriptomic states over arbitrary time scales.

    This is achieved by integrating the reconstructed vector field function from one or a set of initial cell
    state(s). The result is stored in ``adata.uns["fate_<basis>"]`` (or ``"fate"`` / ``"fate_<layer>"`` when no basis
    is given). `dyn.tl._fate` can be used to calculate multiple cell states.

    Args:
        adata: AnnData object that contains the reconstructed vector field function in the `uns` attribute.
        init_cells: Cell names or indices of the initial cell states for the historical or future cell state
            prediction with numerical integration. If the names in init_cells are not found in adata.obs_names,
            they will be treated as cell indices and must be integers.
        init_states: Initial cell states for the historical or future cell state prediction with numerical
            integration.
        basis: The embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca` and
            `inverse_transform` is True, the reconstructed trajectory is projected back to the high dimensional
            space.
        layer: Which layer of the data will be used for predicting cell fate with the reconstructed vector field
            function.
        dims: The dimensions that will be selected for fate prediction. An integer keeps the first `dims`
            columns; a sequence/array selects those columns.
        genes: Gene names intended for fate prediction. NOTE(review): this argument is currently not used by the
            function body.
        t_end: The length of the time period from which to predict cell state forward or backward over time.
        direction: The direction to predict the cell fate. One of `forward`, `backward` or `both`.
        interpolation_num: The number of uniformly interpolated time points.
        average: How to average cell states, one of `False`, `True`, `origin` or `trajectory`. With `origin` the
            average state of init_cells is used as the starting state; with `trajectory` the predicted states of
            all cells are averaged at each time point. `False` disables averaging; `True` means `origin`.
        sampling: Method to sample points along the integration path, one of
            `{'arc_length', 'logspace', 'uniform_indices'}`.
        VecFld_true: The true ODE function, useful when the data is generated through simulation. Replaces the
            reconstructed vector field when set.
        inverse_transform: Whether to inverse transform the low dimensional prediction back to the high
            dimensional space (only for `basis` `pca` or `umap`).
        Qkey: The key of the PCA loading matrix in `.uns`.
        scale: Factor applied to the velocity predicted by the reconstructed vector field function.
        cores: Number of cores used for the path integration. If > 1, a thread pool is used.
        kwargs: Additional parameters passed to `_fate`.

    Returns:
        adata: AnnData object updated with the fate dictionary (includes `t`, `prediction`, `init_states`,
            `init_cells`, `average` and `genes` keys, plus `exprs` after inverse transformation) in `.uns`.

    Raises:
        Exception: In the UMAP inverse transform when the reconstructed gene dimension matches neither
            `use_for_dynamics` nor `use_for_transition`.
    """
    if basis is not None:
        fate_key = "fate_" + basis
    else:
        fate_key = "fate" if layer == "X" else "fate_" + layer

    init_states, VecFld, t_end, valid_genes = fetch_states(
        adata,
        init_states,
        init_cells,
        basis,
        layer,
        average,
        t_end,
    )

    if np.isscalar(dims):
        init_states = init_states[:, :dims]
    elif dims is not None:
        init_states = init_states[:, dims]

    vf = (
        (lambda x: scale * vector_field_function(x=x, vf_dict=VecFld, dim=dims)) if VecFld_true is None else VecFld_true
    )
    t, prediction = _fate(
        vf,
        init_states,
        t_end=t_end,
        direction=direction,
        interpolation_num=interpolation_num,
        # only "trajectory" averaging happens at integration time; "origin" is handled by fetch_states
        average=(average == "trajectory"),
        sampling=sampling,
        cores=cores,
        **kwargs,
    )

    exprs = None
    if basis == "pca" and inverse_transform:
        # Project PCA-space trajectories back to gene space with the loading matrix stored under `Qkey`.
        if isinstance(prediction, list):
            exprs = [vector_transformation(cur_pred.T, adata.uns[Qkey]) for cur_pred in prediction]
            high_p_n = exprs[0].shape[1]
        else:
            exprs = vector_transformation(prediction.T, adata.uns[Qkey])
            high_p_n = exprs.shape[1]

        if adata.var.use_for_dynamics.sum() == high_p_n:
            valid_genes = adata.var_names[adata.var.use_for_dynamics]
        else:
            valid_genes = adata.var_names[adata.var.use_for_transition]

    elif basis == "umap" and inverse_transform:
        # this requires umap 0.4; reverse project to PCA space.
        if hasattr(prediction, "ndim"):
            if prediction.ndim == 1:
                prediction = prediction[None, :]

        params = adata.uns["umap_fit"]
        umap_fit = construct_mapper_umap(
            params["X_data"],
            n_components=params["umap_kwargs"]["n_components"],
            metric=params["umap_kwargs"]["metric"],
            min_dist=params["umap_kwargs"]["min_dist"],
            spread=params["umap_kwargs"]["spread"],
            max_iter=params["umap_kwargs"]["max_iter"],
            alpha=params["umap_kwargs"]["alpha"],
            gamma=params["umap_kwargs"]["gamma"],
            negative_sample_rate=params["umap_kwargs"]["negative_sample_rate"],
            init_pos=params["umap_kwargs"]["init_pos"],
            random_state=params["umap_kwargs"]["random_state"],
            umap_kwargs=params["umap_kwargs"],
        )

        PCs = adata.uns[Qkey].T
        exprs = []

        for cur_pred in prediction:
            expr = umap_fit.inverse_transform(cur_pred.T)

            # further reverse project back to raw expression space (expm1 undoes the log1p of the PCA input)
            if PCs.shape[0] == expr.shape[1]:
                expr = np.expm1(expr @ PCs + adata.uns["pca_mean"])

            exprs.append(expr)

        if adata.var.use_for_dynamics.sum() == exprs[0].shape[1]:
            valid_genes = adata.var_names[adata.var.use_for_dynamics]
        elif adata.var.use_for_transition.sum() == exprs[0].shape[1]:
            valid_genes = adata.var_names[adata.var.use_for_transition]
        else:
            raise Exception(
                "looks like a customized set of genes is used for pca analysis of the adata. "
                "Try rerunning pca analysis with default settings for this function to work."
            )

    adata.uns[fate_key] = {
        "init_states": init_states,
        "init_cells": list(init_cells),
        "average": average,
        "t": t,
        "prediction": prediction,
        "genes": valid_genes,
    }
    if exprs is not None:
        adata.uns[fate_key]["exprs"] = exprs

    return adata
def _fate(
    VecFld: Callable,
    init_states: np.ndarray,
    t_end: Optional[float] = None,
    step_size: Optional[float] = None,
    direction: str = "both",
    interpolation_num: int = 250,
    average: bool = True,
    sampling: str = "arc_length",
    cores: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    """Predict historical/future cell states by integrating a vector field function from initial state(s).

    Args:
        VecFld: Functional form of the vector field reconstructed from sparse single cell samples.
        init_states: Initial cell states, one row per cell.
        t_end: The length of the time period over which to predict cell states forward or backward.
        step_size: Step size for the integration time grid (passed to `getTseq`). When None, `getTseq` chooses it.
        direction: The direction to predict the cell fate. One of `forward`, `backward` or `both`.
        interpolation_num: The number of uniformly interpolated time points.
        average: Whether to average the predicted states of all cells at each time step.
        sampling: Method to sample points along the integration path, one of
            `{'arc_length', 'logspace', 'uniform_indices'}`. `uniform_indices` needs the time points of all cells
            at once and therefore disables multi-threading.
        cores: Number of threads used to integrate cells in parallel. If > 1, each cell is integrated in its own
            task of a thread pool.

    Returns:
        t: The times at which the cell states are predicted (a list with one time array per trajectory in the
            multi-threaded path; in the single-core path whatever `integrate_vf_ivp` returns).
        prediction: Predicted cell states, a list with one `(n_features, n_time_points)` array per trajectory in
            the multi-threaded path. With `average=True` and several cells, a single averaged trajectory is
            returned.
    """
    if sampling == "uniform_indices":
        if cores > 1:
            main_warning(
                "Uniform_indices method sample data points from all time points. The multiprocessing will be disabled."
            )
        cores = 1

    t_linspace = getTseq(init_states, t_end, step_size)

    if cores == 1:
        # integrate_vf_ivp handles multiple cells and averaging itself.
        t, prediction = integrate_vf_ivp(
            init_states,
            t_linspace,
            direction,
            VecFld,
            interpolation_num=interpolation_num,
            average=average,
            sampling=sampling,
        )
    else:
        pool = ThreadPool(cores)
        try:
            res = pool.starmap(
                integrate_vf_ivp,
                zip(
                    init_states,  # one row (cell) per task
                    itertools.repeat(t_linspace),
                    itertools.repeat(direction),
                    itertools.repeat(VecFld),
                    itertools.repeat(()),
                    itertools.repeat(interpolation_num),
                    itertools.repeat(average),
                    itertools.repeat(sampling),
                    itertools.repeat(False),
                    itertools.repeat(True),  # disable tqdm when using multiple cores.
                ),
            )
        finally:
            # release the worker threads even if an integration task raised
            pool.close()
            pool.join()

        # every task integrated a single cell, so unwrap its one-element results
        t_, prediction_ = zip(*res)
        t, prediction = [i[0] for i in t_], [i[0] for i in prediction_]

        if init_states.shape[0] > 1 and average:
            t_stack, prediction_stack = np.hstack(t), np.hstack(prediction)
            n_cell, n_feature = init_states.shape

            t_len = int(len(t_stack) / n_cell)
            # column c * t_len + j of the stacked predictions holds time step j of cell c
            cols = np.arange(n_cell)[:, None] * t_len + np.arange(t_len)[None, :]
            avg = np.zeros((n_feature, t_len))
            avg[:, :] = prediction_stack[:, cols].mean(axis=1)

            prediction = [avg]
            # NOTE(review): this pools the time points of all cells; when cells were sampled on different time
            # grids (e.g. arc_length), len(t[0]) may exceed t_len — confirm intended.
            t = [np.sort(np.unique(t))]

    return t, prediction
def _query_nearest_cells(nbrs, points: np.ndarray, k: int, metric: str) -> Tuple[np.ndarray, np.ndarray]:
    """Return ``(knn, distances)`` of ``points`` against the observed cells, dispatching on the kNN backend.

    The backend is recognised by the method ``nbrs`` exposes: ``query`` (pynndescent), ``knn_query`` (hnswlib) or
    ``kneighbors`` (scikit-learn; ``k`` is not passed there, the neighbour count fixed at construction is used).
    """
    if hasattr(nbrs, "query"):
        return nbrs.query(points, k=k)
    if hasattr(nbrs, "knn_query"):
        knn, distances = nbrs.knn_query(points, k=k)
        # hnswlib's l2 space reports squared distances
        if metric == "euclidean":
            distances = np.sqrt(distances)
        return correct_hnsw_neighbors(knn, distances)
    distances, knn = nbrs.kneighbors(points)
    return knn, distances


def _fate_bias_time_indices(
    prediction: np.ndarray,
    cur_t: np.ndarray,
    inds,
    use_sink_percentage: bool,
    step_used_percentage: Optional[float],
    speed_percentile: float,
) -> np.ndarray:
    """Select the time-step indices of one predicted trajectory (``(n_features, n_steps)``) used for fate bias."""
    n_steps = len(cur_t)
    if inds is not None:
        # as an array so that the walk-back arithmetic (indices - 1) works for plain lists too
        return np.asarray(inds)
    if use_sink_percentage:
        avg_speed = np.linalg.norm(np.diff(prediction, 1), axis=0) / np.diff(cur_t)
        fast_from_end = np.where(avg_speed[::-1] > np.percentile(avg_speed, speed_percentile))[0]
        # number of trailing steps whose speed is at or below the percentile (the "sink"); all steps if none is faster
        n_sink = fast_from_end.min() if fast_from_end.size else len(avg_speed)
        return np.arange(max(n_steps - max(n_sink, 10), 0), n_steps)
    if step_used_percentage is not None:
        # use the LAST fraction `step_used_percentage` of the steps
        return np.arange(int(n_steps - step_used_percentage * n_steps), n_steps)
    main_info("using all steps data")
    return np.arange(0, n_steps)


def fate_bias(
    adata: AnnData,
    group: str,
    basis: str = "umap",
    inds: Union[list, None] = None,
    use_sink_percentage: bool = True,
    step_used_percentage: Optional[float] = None,
    speed_percentile: float = 5,
    dist_threshold: Optional[float] = None,
    source_groups: Optional[list] = None,
    metric: str = "euclidean",
    metric_kwds: Optional[dict] = None,
    cores: int = 1,
    seed: int = 19491001,
    **kwargs,
) -> pd.DataFrame:
    """Calculate the lineage (fate) bias of states whose trajectory are predicted.

    Fate bias is calculated as the percentage of observed cells of each group (column `group` of `.obs`) found
    around the end region of each predicted trajectory. We first identify the time steps with small velocity in
    the tail of the integration path (determined by `speed_percentile`), then check whether the nearest observed
    cells of those points are close to the data (closer than `dist_threshold` times the median 1-st nearest
    neighbour distance of all observed cells). If no point is close, we walk backwards along the trajectory one
    time step at a time until some point is. We then diffuse one step further from the identified neighbours (to
    reach cells of terminal cell types in case the first neighbours are random progenitors) and use the group
    labels of those cells as the fate probability.

    A confidence score is reported for each trajectory:
    :math:`1 - (sum(distances > dist_threshold * median_dist) + walk_back_steps) / (len(distances) + walk_back_steps)`
    where `distances` are the neighbour distances of the currently visited points and `walk_back_steps` is the
    number of steps walked backwards.

    Args:
        adata: AnnData object that contains the predicted fate trajectories in the `uns` attribute.
        group: The `.obs` column with the cell type or other group information used to quantify the bias.
        basis: The embedding where cell fates were predicted and fate bias will be quantified. None means `.X`.
        inds: Explicit time-step indices (list of integers) used for calculating fate bias.
        use_sink_percentage: If inds is None and this is True, the sink (slow tail) of each trajectory is used.
        step_used_percentage: If inds is None, use_sink_percentage is False and this is not None, the LAST
            `step_used_percentage` fraction (0-1) of the steps is used.
        speed_percentile: Speed percentile below which the tail of the trajectory is regarded as the sink region.
        dist_threshold: Multiplier of the median nearest-cell distance that decides whether a predicted point is
            close to observed cells. Defaults to 1.
        source_groups: Progenitor groups. They need to intersect with the groups of the `group` column. Their fate
            probability is reassigned to the most probable non-source group.
        metric: The distance metric used for the neighbour search.
        metric_kwds: Additional keyword arguments for the metric function.
        cores: The number of parallel jobs to run for neighbours search.
        seed: Random seed to ensure the reproducibility of each run.
        kwargs: Additional arguments passed to the nearest neighbour search algorithm.

    Returns:
        fate_bias: A DataFrame with a `confidence` column plus the fate probability of each cell group (column)
            for each predicted trajectory (row, indexed by the `init_cells` of the fate prediction).

    Raises:
        ValueError: If `group`, the basis, the fate prediction, or an intersecting `source_groups` is missing.
    """
    if dist_threshold is None:
        dist_threshold = 1

    if group not in adata.obs.keys():
        raise ValueError(f"The group {group} you provided is not a key of .obs attribute.")
    clusters = adata.obs[group]

    basis_key = "X_" + basis if basis is not None else "X"
    fate_key = "fate_" + basis if basis is not None else "fate"

    # "X" is read from adata.X below, so it does not need to live in .obsm
    if basis_key != "X" and basis_key not in adata.obsm.keys():
        raise ValueError(f"The basis {basis_key} you provided is not a key of .obsm attribute.")
    if fate_key not in adata.uns.keys():
        raise ValueError(
            f"The {fate_key} key is not existed in the .uns attribute of the adata object. You need to run "
            f"dyn.pd.fate(adata, basis='{basis}') before calculate fate bias."
        )

    if source_groups is not None:
        requested_groups = [source_groups] if isinstance(source_groups, str) else list(source_groups)
        source_groups = list(set(requested_groups).intersection(clusters))
        if len(source_groups) == 0:
            raise ValueError(
                f"the {requested_groups} you provided doesn't intersect with any groups in the {group} column."
            )

    X = adata.obsm[basis_key] if basis_key != "X" else adata.X
    _, obs_distances, nbrs, _ = k_nearest_neighbors(
        X,
        k=29,
        metric=metric,
        metric_kwads=metric_kwds,
        exclude_self=False,
        pynn_rand_state=seed,
        return_nbrs=True,
        n_jobs=cores,
        **kwargs,
    )
    # column 1: nearest neighbour other than (presumably) the cell itself, since exclude_self=False
    median_dist = np.median(obs_distances[:, 1])

    pred_dict = {}
    cell_predictions, cell_indx = (
        adata.uns[fate_key]["prediction"],
        adata.uns[fate_key]["init_cells"],
    )
    t = adata.uns[fate_key]["t"]
    confidence = np.zeros(len(t))

    for i, prediction in tqdm(
        enumerate(cell_predictions), total=len(cell_predictions), desc="calculating fate distributions"
    ):
        indices = _fate_bias_time_indices(
            prediction, t[i], inds, use_sink_percentage, step_used_percentage, speed_percentile
        )
        # NOTE(review): this first query keeps all 30 neighbours per point, while the walk-back below keeps only
        # the nearest one; the confidence denominator therefore differs between both cases — confirm intended.
        knn, distances = _query_nearest_cells(nbrs, prediction[:, indices].T, 30, metric)

        # if final steps too far away from observed cells, ignore them
        walk_back_steps = 0
        while True:
            is_close = distances.flatten() < dist_threshold * median_dist
            if is_close.any():
                # let us diffuse one step further to identify cells from terminal cell types in case
                # cells with indices are all close to some random progenitor cells.
                knn, _ = _query_nearest_cells(nbrs, X[knn.flatten(), :], 30, metric)
                flat_knn = knn.flatten()
                fate_prob = clusters.iloc[flat_knn].value_counts() / flat_knn.size

                if source_groups is not None:
                    present_sources = fate_prob.index.intersection(source_groups)
                    source_p = fate_prob.loc[present_sources].sum()
                    if 1 > source_p > 0:
                        fate_prob.loc[present_sources] = 0
                        fate_prob.loc[fate_prob.idxmax()] += source_p

                pred_dict[i] = fate_prob
                confidence[i] = 1 - (np.sum(~is_close) + walk_back_steps) / (len(is_close) + walk_back_steps)
                break

            walk_back_steps += 1
            if np.any(indices - 1 < 0):
                # reached the start of the trajectory without getting close to the data
                pred_dict[i] = clusters.iloc[knn.flatten()].value_counts() * np.nan
                break

            indices = indices - 1
            knn, distances = _query_nearest_cells(nbrs, prediction[:, indices].T, 30, metric)
            knn, distances = knn[:, 0], distances[:, 0]

    bias = pd.DataFrame(pred_dict).T
    conf = pd.DataFrame({"confidence": confidence}, index=bias.index)
    bias = pd.merge(conf, bias, left_index=True, right_index=True)

    if cell_indx is not None:
        bias.index = cell_indx

    return bias
# def fate_(adata, time, direction = 'forward'):
# from .moments import *
# gene_exprs = adata.X
# cell_num, gene_num = gene_exprs.shape
#
#
# for i in range(gene_num):
# params = {'a': adata.uns['dynamo'][i, "a"], \
# 'b': adata.uns['dynamo'][i, "b"], \
# 'la': adata.uns['dynamo'][i, "la"], \
# 'alpha_a': adata.uns['dynamo'][i, "alpha_a"], \
# 'alpha_i': adata.uns['dynamo'][i, "alpha_i"], \
# 'sigma': adata.uns['dynamo'][i, "sigma"], \
# 'beta': adata.uns['dynamo'][i, "beta"], \
# 'gamma': adata.uns['dynamo'][i, "gamma"]}
# mom = moments_simple(**params)
# for j in range(cell_num):
# x0 = gene_exprs[i, j]
# mom.set_initial_condition(*x0)
# if direction == "forward":
# gene_exprs[i, j] = mom.solve([0, time])
# elif direction == "backward":
# gene_exprs[i, j] = mom.solve([0, - time])
#
# adata.uns['prediction'] = gene_exprs
# return adata
def andecestor(
    adata: AnnData,
    init_cells: List,
    init_states: Optional[np.ndarray] = None,
    cores: int = 1,
    t_end: int = 50,
    basis: str = "umap",
    n_neighbors: int = 5,
    direction: str = "forward",
    interpolation_num: int = 250,
    last_point_only: bool = False,
    metric: str = "euclidean",
    metric_kwds: Optional[dict] = None,
    seed: int = 19491001,
    **kwargs,
) -> None:
    """Predict the ancestors or descendants of a group of initial cells (states) with the given vector field function.

    Args:
        adata: AnnData object that contains the reconstructed vector field function in the `uns` attribute.
        init_cells: Cell names or integer indices of the initial cells. They are excluded from the result.
        init_states: Initial cell states for the prediction. When None, the `basis` embedding of `init_cells` is
            used; otherwise it must have as many columns as that embedding.
        cores: Number of cores used for the nearest neighbor graph and the path integration.
        t_end: The length of the time period over which to predict cell states forward or backward.
        basis: The embedding (`adata.obsm["X_" + basis]`) used for predicting cell fate.
        n_neighbors: Number of nearest observed cells collected for each predicted state.
        direction: The direction to predict the cell fate. One of `forward`, `backward` or `both`.
        interpolation_num: The number of uniformly interpolated time points.
        last_point_only: If True, only the last predicted point of each trajectory (the first and the last point
            when `direction` is `both`) is used to find nearest cells; otherwise every predicted point is used.
        metric: The distance metric used for the neighbor search.
        metric_kwds: Additional keyword arguments for the metric function.
        seed: Random seed to ensure the reproducibility of each run.
        kwargs: Additional arguments passed to the nearest neighbor search algorithm.

    Returns:
        Nothing, but adds a boolean `.obs` column — `descendant` (forward), `ancestor` (backward) or `lineage`
        (both) — marking the predicted cells.

    Raises:
        Exception: If `init_states` does not have as many columns as the embedding.
    """
    logger = LoggerManager.gen_logger("dynamo-andecestor")
    logger.log_time()

    main_info("retrieve vector field function.")
    _, vecfld = vecfld_from_adata(adata, basis=basis)

    basis_key = "X_" + basis
    X = adata.obsm[basis_key].copy()

    main_info("build a kNN graph structure so we can query the nearest cells of the predicted states.")
    _, _, nbrs, alg = k_nearest_neighbors(
        X,
        k=n_neighbors - 1,
        metric=metric,
        metric_kwads=metric_kwds,
        exclude_self=False,
        pynn_rand_state=seed,
        n_jobs=cores,
        return_nbrs=True,
        logger=logger,
        **kwargs,
    )

    if init_states is None:
        init_states = adata[init_cells, :].obsm[basis_key]
    elif init_states.shape[1] != adata.obsm[basis_key].shape[1]:
        raise Exception(
            f"init_states has to have the same columns as adata.obsm[{basis_key}] but you have "
            f"{init_states.shape[1]}"
        )

    main_info("predict cell state trajectory via integrating vector field function.")
    _, pred = _fate(
        vecfld,
        init_states,
        t_end=t_end,
        interpolation_num=interpolation_num,
        average=False,
        sampling="arc_length",
        cores=cores,
        direction=direction,
    )

    nearest_cell_inds = []
    main_info("identify the progenitors/descendants by finding predicted cell states' nearest cells.")
    # a "both" trajectory has two end points: its first and its last predicted state
    last_indices = [0, -1] if direction == "both" else [-1]
    for cur_pred in pred:
        queries = cur_pred.T[last_indices] if last_point_only else cur_pred.T
        if alg == "pynn":
            knn, _ = nbrs.query(queries, k=n_neighbors)
        elif alg == "hnswlib":
            knn, distances = nbrs.knn_query(queries, k=n_neighbors)
            if metric == "euclidean":
                distances = np.sqrt(distances)
            knn, _ = correct_hnsw_neighbors(knn, distances)
        else:
            _, knn = nbrs.kneighbors(queries)

        nearest_cell_inds.extend(knn.flatten())

    # integer dtype keeps obs_names indexing valid even when no neighbour was collected
    nearest_cell_inds = np.unique(np.asarray(nearest_cell_inds, dtype=int))

    if init_cells is not None and len(init_cells) > 0:
        # numpy integers (e.g. from np.where) are indices too, not only Python ints
        if isinstance(init_cells[0], (int, np.integer)):
            init_cells = adata.obs_names[init_cells]
        nearest_cells = list(set(adata.obs_names[nearest_cell_inds]).difference(init_cells))
    else:
        nearest_cells = list(adata.obs_names[nearest_cell_inds])

    obs_key = "descendant" if direction == "forward" else "ancestor" if direction == "backward" else "lineage"
    main_info_insert_adata(obs_key)
    adata.obs[obs_key] = False
    adata.obs.loc[nearest_cells, obs_key] = True

    logger.finish_progress(progress_name=f"predict {obs_key}")
```