from typing import Callable, Optional, Tuple, Union
import numpy as np
import scipy
from anndata import AnnData
from scipy.interpolate import interp1d
from ..dynamo_logger import LoggerManager
from ..tools.utils import flatten
from ..utils import expr_to_pca, pca_to_expr
from ..vectorfield.scVectorField import DifferentiableVectorField
from ..vectorfield.topography import dup_osc_idx_iter
from ..vectorfield.utils import angle, normalize_vectors
class Trajectory:
    """Base class for handling trajectory interpolation, resampling, etc."""

    def __init__(self, X: np.ndarray, t: Union[None, np.ndarray] = None, sort: bool = True) -> None:
        """Initializes a Trajectory object.

        Args:
            X: trajectory positions, shape (n_points, n_dimensions)
            t: trajectory times, shape (n_points,). Defaults to None.
            sort: whether to sort the points by their time stamps. Defaults to True.
        """
        self.X = X
        if t is None:
            self.t = None
        else:
            self.set_time(t, sort=sort)

    def __len__(self) -> int:
        """Returns the number of points in the trajectory.

        Returns:
            number of points in the trajectory
        """
        return self.X.shape[0]

    def set_time(self, t: np.ndarray, sort: bool = True) -> None:
        """Set the time stamps for the trajectory. Sorts the time stamps if requested.

        When sorting, the positions in `self.X` are reordered together with the time stamps.

        Args:
            t: trajectory times, shape (n_points,)
            sort: whether to sort the time stamps. Defaults to True.
        """
        if sort:
            order = np.argsort(t)
            self.t = t[order]
            self.X = self.X[order]
        else:
            self.t = t

    def dim(self) -> int:
        """Returns the number of dimensions in the trajectory.

        Returns:
            number of dimensions in the trajectory
        """
        return self.X.shape[1]

    def calc_tangent(self, normalize: bool = True) -> np.ndarray:
        """Calculate the tangent vectors of the trajectory.

        Args:
            normalize: whether to normalize the tangent vectors. Defaults to True.

        Returns:
            tangent vectors of the trajectory, shape (n_points-1, n_dimensions)
        """
        tvec = self.X[1:] - self.X[:-1]
        if normalize:
            tvec = normalize_vectors(tvec)
        return tvec

    def calc_arclength(self) -> float:
        """Calculate the arc length of the trajectory.

        Returns:
            arc length of the trajectory, i.e. the summed length of all segments
        """
        tvec = self.calc_tangent(normalize=False)
        norms = np.linalg.norm(tvec, axis=1)
        return np.sum(norms)

    def calc_curvature(self) -> np.ndarray:
        """Calculate the curvature of the trajectory.

        The two end points have no pair of adjacent segments, so their curvature is left at zero.

        Returns:
            curvature of the trajectory, shape (n_points,)
        """
        tvec = self.calc_tangent(normalize=False)
        kappa = np.zeros(self.X.shape[0])
        for i in range(1, self.X.shape[0] - 1):
            # ref: http://www.cs.jhu.edu/~misha/Fall09/1-curves.pdf (p. 55)
            # turning angle divided by the mean length of the two adjacent segments
            kappa[i] = angle(tvec[i - 1], tvec[i]) / (np.linalg.norm(tvec[i - 1] / 2) + np.linalg.norm(tvec[i] / 2))
        return kappa

    def resample(
        self, n_points: int, tol: float = 1e-4, inplace: bool = True
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Resample the curve with the specified number of points.

        Args:
            n_points: An integer specifying the number of points in the resampled curve.
            tol: A float specifying the tolerance for removing redundant points. Currently unused; kept for
                backward compatibility.
            inplace: A boolean flag indicating whether to modify the curve object in place. Default is True.

        Returns:
            A tuple containing the resampled curve coordinates and the resampled time values. The second element
            is None when the trajectory has no time stamps.

        TODO:
            Decide whether the tol argument should be included or not during the code refactoring and optimization.
        """
        # resample using the arclength sampling
        ret = arclength_sampling_n(self.X, n_points, t=self.t)
        X = ret[0]
        # `arclength_sampling_n` only returns a third element when time stamps were supplied.
        t = ret[2] if self.t is not None else None
        if inplace:
            self.X, self.t = X, t
        return X, t

    def _evaluate_solution(self, sol, integration_direction: str) -> np.ndarray:
        """Evaluate the ODE solution(s) at the current time stamps `self.t`."""
        if integration_direction != "both":
            return sol(self.t)
        # For "both", `sol` holds two solutions: sol[0] is evaluated at the negative
        # time stamps and sol[1] at the remaining ones.
        neg_t_len = int(np.count_nonzero(np.asarray(self.t) < 0))
        return np.hstack(
            (
                sol[0](self.t[:neg_t_len]),
                sol[1](self.t[neg_t_len:]),
            )
        )

    def archlength_sampling(
        self,
        sol: scipy.integrate._ivp.common.OdeSolution,
        interpolation_num: int,
        integration_direction: str,
    ) -> None:
        """Sample the curve using archlength sampling.

        The time stamps are resampled uniformly in arc length and `self.X` is then re-evaluated from `sol` at
        those time stamps.

        Args:
            sol: The ODE solution from scipy.integrate.solve_ivp.
            interpolation_num: The number of points to interpolate the curve at.
            integration_direction: The direction to integrate the curve in. Can be "forward", "backward", or "both".
        """
        # NOTE(review): the transpose suggests `self.X` is laid out differently here than the
        # (n_points, n_dimensions) convention documented in `__init__` -- confirm against the caller.
        tau, x = self.t, self.X.T
        # `np.ptp` is used because the `ndarray.ptp` method was removed in NumPy 2.0.
        idx = dup_osc_idx_iter(x, max_iter=100, tol=np.ptp(x, axis=0).mean() / 1000)[0]
        x = x[:idx]
        # Only the resampled time stamps are needed; the positions are recomputed from `sol` below.
        resampled_t = arclength_sampling_n(x, num=interpolation_num + 1, t=tau[:idx])[2]
        # Drop the first sample.
        self.t = resampled_t[1:]
        self.X = self._evaluate_solution(sol, integration_direction)

    def logspace_sampling(
        self,
        sol: scipy.integrate._ivp.common.OdeSolution,
        interpolation_num: int,
        integration_direction: str,
    ) -> None:
        """Sample the curve using logspace sampling.

        The time stamps are replaced by logarithmically spaced values and `self.X` is then re-evaluated from
        `sol` at those time stamps.

        Args:
            sol: The ODE solution from scipy.integrate.solve_ivp.
            interpolation_num: The number of points to interpolate the curve at.
            integration_direction: The direction to integrate the curve in. Can be "forward", "backward", or "both".
        """
        tau = self.t
        neg_tau, pos_tau = tau[tau < 0], tau[tau >= 0]
        if len(neg_tau) > 0:
            # NOTE(review): this yields values in [-|min(neg_tau)| - 2, -2], whereas the positive branch
            # yields [0, max(pos_tau)]. `-(logspace - 1)` may have been intended -- confirm before changing.
            t_0 = -np.logspace(0, np.log10(abs(min(neg_tau)) + 1), interpolation_num) - 1
            t_1 = np.logspace(0, np.log10(max(pos_tau) + 1), interpolation_num) - 1
            self.t = np.hstack((t_0[::-1], t_1))
        else:
            self.t = np.logspace(0, np.log10(max(tau) + 1), interpolation_num) - 1
        self.X = self._evaluate_solution(sol, integration_direction)

    def interpolate(self, t: np.ndarray, **interp_kwargs) -> np.ndarray:
        """Interpolate the curve at new time values.

        Args:
            t: The new time values at which to interpolate the curve.
            **interp_kwargs: Additional arguments to pass to `scipy.interpolate.interp1d`.

        Returns:
            The interpolated values of the curve at the specified time values.

        Raises:
            Exception: If `self.t` is `None`, which is needed for interpolation.
        """
        if self.t is None:
            raise Exception("`self.t` is `None`, which is needed for interpolation.")
        return interp1d(self.t, self.X, axis=0, **interp_kwargs)(t)

    def interp_t(self, num: int = 100) -> np.ndarray:
        """Interpolates the `t` parameter linearly.

        Args:
            num: Number of interpolation points.

        Returns:
            The array of `num` equally spaced values between the first and last time stamp.

        Raises:
            Exception: If `self.t` is `None`, which is needed for interpolation.
        """
        if self.t is None:
            raise Exception("`self.t` is `None`, which is needed for interpolation.")
        return np.linspace(self.t[0], self.t[-1], num=num)

    def interp_X(self, num: int = 100, **interp_kwargs) -> np.ndarray:
        """Interpolates the curve at `num` equally spaced points in `t`.

        Args:
            num: The number of points to interpolate the curve at.
            **interp_kwargs: Additional keyword arguments to pass to `scipy.interpolate.interp1d`.

        Returns:
            The interpolated curve at `num` equally spaced points in `t`.

        Raises:
            Exception: If `self.t` is `None`, which is needed for interpolation.
        """
        if self.t is None:
            raise Exception("`self.t` is `None`, which is needed for interpolation.")
        return self.interpolate(self.interp_t(num=num), **interp_kwargs)

    def integrate(self, func: Callable) -> np.ndarray:
        """Calculate the integral of a function along the curve.

        Only the interior points contribute; each is weighted by the mean length of its two adjacent segments.

        Args:
            func: A function to integrate along the curve. It is called with a single point of the trajectory.

        Returns:
            The integral of func along the discrete curve.
        """
        F = np.zeros(func(self.X[0]).shape)
        tvec = self.calc_tangent(normalize=False)
        for i in range(1, self.X.shape[0] - 1):
            # ref: http://www.cs.jhu.edu/~misha/Fall09/1-curves.pdf P. 47
            F += func(self.X[i]) * (np.linalg.norm(tvec[i - 1]) + np.linalg.norm(tvec[i])) / 2
        return F

    def calc_msd(self, decomp_dim: bool = True, ref: int = 0) -> Union[float, np.ndarray]:
        """Calculate the mean squared displacement (MSD) of the curve with respect to a reference point.

        Args:
            decomp_dim: If True, return the MSD of each dimension separately. If False, return the total MSD.
            ref: Index of the reference point. Default is 0.

        Returns:
            The MSD of the curve with respect to the reference point.
        """
        S = (self.X - self.X[ref]) ** 2
        S = S.sum(axis=0) if decomp_dim else S.sum()
        # Divide out of place: in-place `/=` cannot store the float result in an integer array.
        return S / len(self)
class VectorFieldTrajectory(Trajectory):
    """Trajectory that carries a differentiable vector field and caches quantities derived from it."""

    def __init__(self, X: np.ndarray, t: np.ndarray, vecfld: DifferentiableVectorField) -> None:
        """Create a trajectory bound to a vector field.

        Args:
            X: The trajectory data as a numpy array of shape (n, d).
            t: The time data as a numpy array of shape (n,).
            vecfld: The differentiable vector field that describes the trajectory.
        """
        super().__init__(X, t=t)
        self.vecfld = vecfld
        # Cache of per-point quantities; an entry stays None until it is first requested.
        self.data = dict.fromkeys(("velocity", "acceleration", "curvature", "divergence"))
        # Cached Jacobians, shared by acceleration, curvature and divergence.
        self.Js = None

    def get_velocities(self) -> np.ndarray:
        """Return the velocities along the trajectory, evaluating the vector field on first use.

        Returns:
            The velocity data as a numpy array of shape (n, d).
        """
        if self.data["velocity"] is not None:
            return self.data["velocity"]
        self.data["velocity"] = self.vecfld.func(self.X)
        return self.data["velocity"]

    def get_jacobians(self, method: Optional[str] = None) -> np.ndarray:
        """Return the Jacobians of the vector field along the trajectory, computing them on first use.

        Args:
            method: The method used to compute the Jacobians; forwarded to `vecfld.get_Jacobian`.

        Returns:
            The Jacobian data as a numpy array of shape (n, d, d).
        """
        if self.Js is not None:
            return self.Js
        jacobian_func = self.vecfld.get_Jacobian(method=method)
        self.Js = jacobian_func(self.X)
        return self.Js

    def get_accelerations(self, method: Optional[str] = None, **kwargs) -> np.ndarray:
        """Return the accelerations along the trajectory, computing them on first use.

        Args:
            method: The method used to compute the Jacobians.
            **kwargs: Additional keyword arguments forwarded to `vecfld.compute_acceleration`.

        Returns:
            The acceleration data as a numpy array of shape (n, d).
        """
        if self.data["acceleration"] is not None:
            return self.data["acceleration"]
        # `get_jacobians` returns the cached Jacobians when they already exist.
        jacobians = self.get_jacobians(method=method)
        self.data["acceleration"] = self.vecfld.compute_acceleration(self.X, Js=jacobians, **kwargs)
        return self.data["acceleration"]

    def get_curvatures(self, method: Optional[str] = None, **kwargs) -> np.ndarray:
        """Return the curvatures along the trajectory, computing them on first use.

        Args:
            method: The method used to compute the Jacobians.
            **kwargs: Additional keyword arguments forwarded to `vecfld.compute_curvature`.

        Returns:
            The curvature data as a numpy array of shape (n,).
        """
        if self.data["curvature"] is not None:
            return self.data["curvature"]
        jacobians = self.get_jacobians(method=method)
        self.data["curvature"] = self.vecfld.compute_curvature(self.X, Js=jacobians, **kwargs)
        return self.data["curvature"]

    def get_divergences(self, method: Optional[str] = None, **kwargs) -> np.ndarray:
        """Return the divergences along the trajectory, computing them on first use.

        Args:
            method: The method used to compute the Jacobians.
            **kwargs: Additional keyword arguments forwarded to `vecfld.compute_divergence`.

        Returns:
            The divergence data as a numpy array of shape (n,).
        """
        if self.data["divergence"] is not None:
            return self.data["divergence"]
        jacobians = self.get_jacobians(method=method)
        self.data["divergence"] = self.vecfld.compute_divergence(self.X, Js=jacobians, **kwargs)
        return self.data["divergence"]

    def calc_vector_msd(self, key: str, decomp_dim: bool = True, ref: int = 0) -> Union[np.ndarray, float]:
        """Mean squared displacement of a cached vector quantity relative to a reference point.

        The quantity is read from `self.data[key]` as-is; it is not computed here, so the matching
        `get_*` method must have been called beforehand.

        Args:
            key: The key in `self.data` of the quantity to compute the mean squared displacement of.
            decomp_dim: Whether to decompose the MSD by dimension. Defaults to True.
            ref: The index of the reference point to use for computing the MSD. Defaults to 0.

        Returns:
            The mean squared displacement of the specified quantity along the trajectory.

        TODO:
            Discuss should we also calculate other quantities during the code refactoring and
            optimization phase (e.g. curl, hessian, laplacian, etc).
        """
        values = self.data[key]
        squared = (values - values[ref]) ** 2
        S = squared.sum(axis=0) if decomp_dim else squared.sum()
        S /= len(self)
        return S
class GeneTrajectory(Trajectory):
    """Class for handling gene expression trajectory data."""

    def __init__(
        self,
        adata: AnnData,
        X: Optional[np.ndarray] = None,
        t: Optional[np.ndarray] = None,
        X_pca: Optional[np.ndarray] = None,
        PCs: str = "PCs",
        mean: str = "pca_mean",
        genes: str = "use_for_pca",
        expr_func: Optional[Callable] = None,
        **kwargs,
    ) -> None:
        """Initializes a GeneTrajectory object.

        If both `X_pca` and `X` are given, `X` takes precedence. If neither is given, the trajectory
        positions are left unset until `from_pca` is called.

        Args:
            adata: Anndata object containing the gene expression data.
            X: The gene expression data as a numpy array of shape (n, d). Defaults to None.
            t: The time data as a numpy array of shape (n,). Defaults to None.
            X_pca: The PCA-transformed gene expression data as a numpy array of shape (n, d). Defaults to None.
            PCs: The key in adata.uns to use for the PCA components, or the components themselves.
                Defaults to "PCs".
            mean: The key in adata.uns to use for the PCA mean, or the mean itself. Defaults to "pca_mean".
            genes: The key in adata.var to use for the genes, or the gene names themselves.
                Defaults to "use_for_pca".
            expr_func: A function to transform the PCA-transformed gene expression data back to the original space.
                Defaults to None.
            **kwargs: Additional keyword arguments to be passed to the superclass initializer.
        """
        self.adata = adata
        if isinstance(PCs, str):
            PCs = self.adata.uns[PCs]
        self.PCs = PCs
        if isinstance(mean, str):
            mean = self.adata.uns[mean]
        self.mean = mean
        self.expr_func = expr_func
        if isinstance(genes, str):
            genes = adata.var_names[adata.var[genes]].to_list()
        self.genes = np.array(genes)
        if X_pca is not None:
            self.from_pca(X_pca, t=t, **kwargs)
        if X is not None:
            super().__init__(X, t=t, **kwargs)

    def from_pca(self, X_pca: np.ndarray, t: Optional[np.ndarray] = None, **kwargs) -> None:
        """Converts PCA-transformed gene expression data to gene expression data.

        The converted data (and `t`) replace the current trajectory.

        Args:
            X_pca: The PCA-transformed gene expression data as a numpy array of shape (n, d).
            t: The time data as a numpy array of shape (n,). Defaults to None.
            **kwargs: Additional keyword arguments to be passed to the superclass initializer.
        """
        X = pca_to_expr(X_pca, self.PCs, mean=self.mean, func=self.expr_func)
        super().__init__(X, t=t, **kwargs)

    def to_pca(self, x: Optional[np.ndarray] = None) -> np.ndarray:
        """Converts gene expression data to PCA-transformed gene expression data.

        Args:
            x: The gene expression data as a numpy array of shape (n, d). Defaults to None, in which case the
                trajectory positions `self.X` are used.

        Returns:
            The PCA-transformed gene expression data as a numpy array of shape (n, d).
        """
        if x is None:
            x = self.X
        return expr_to_pca(x, self.PCs, mean=self.mean, func=self.expr_func)

    def genes_to_mask(self) -> np.ndarray:
        """Returns a boolean mask for the genes in the trajectory.

        Returns:
            A boolean array of length `adata.n_vars` that is True where the variable name is in `self.genes`.
        """
        return np.isin(np.asarray(self.adata.var_names), self.genes)

    def calc_msd(self, save_key: str = "traj_msd", **kwargs) -> Union[float, np.ndarray]:
        """Calculate the mean squared displacement (MSD) of the gene expression trajectory.

        The result is also stored in `adata.var[save_key]`; genes outside the trajectory are set to NaN.

        Args:
            save_key: The key to save the MSD data to in adata.var. Defaults to "traj_msd".
            **kwargs: Additional keyword arguments to be passed to the superclass method.

        Returns:
            The mean squared displacement of the gene expression trajectory.
        """
        msd = super().calc_msd(**kwargs)
        LoggerManager.main_logger.info_insert_adata(save_key, "var")
        # Fill a plain array first and assign it once: writing through chained
        # indexing on `adata.var[save_key][mask]` may modify a temporary copy.
        values = np.full(self.adata.n_vars, np.nan)
        values[self.genes_to_mask()] = msd
        self.adata.var[save_key] = values
        return msd

    def save(self, save_key: str = "gene_trajectory") -> None:
        """Save the gene expression trajectory to adata.varm.

        Rows of genes outside the trajectory are filled with NaN.

        Args:
            save_key: The key to save the gene expression trajectory to in adata.varm. Defaults to
                "gene_trajectory".
        """
        LoggerManager.main_logger.info_insert_adata(save_key, "varm")
        # Build the full array before storing it so the stored value never depends on
        # `varm` handing back a writable reference.
        values = np.full((self.adata.n_vars, self.X.shape[0]), np.nan)
        values[self.genes_to_mask(), :] = self.X.T
        self.adata.varm[save_key] = values

    def select_gene(
        self,
        genes: Union[np.ndarray, list],
        arr: Optional[np.ndarray] = None,
        axis: Optional[int] = None,
    ) -> np.ndarray:
        """Selects the gene expression data for the specified genes.

        Genes that are not part of the trajectory are skipped with a warning.

        Args:
            genes: The genes to select the expression data for.
            arr: The array to select the genes from. Defaults to None, in which case `self.X` is used.
            axis: The axis to select the genes along. Forced to 0 for one-dimensional arrays; otherwise
                defaults to 1.

        Returns:
            The gene expression data for the specified genes, one entry per selected gene.

        Raises:
            Exception: If `self.genes` is `None`.
        """
        if arr is None:
            arr = self.X
        if arr.ndim == 1:
            axis = 0
        elif axis is None:
            axis = 1

        if self.genes is None:
            raise Exception("Cannot select genes since `self.genes` is `None`.")

        y = []
        for g in genes:
            if g not in self.genes:
                LoggerManager.main_logger.warning(f"{g} is not in `self.genes`.")
                continue
            if axis == 0:
                y.append(flatten(arr[self.genes == g]))
            elif axis == 1:
                y.append(flatten(arr[:, self.genes == g]))
        return np.array(y)
def arclength_sampling_n(
    X: np.ndarray,
    num: int,
    t: Optional[np.ndarray] = None,
) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, np.ndarray]]:
    """Resample a curve at `num` positions equally spaced in arc length.

    Args:
        X: The points of the curve, one point per row.
        num: The number of points to sample.
        t: The time values for the data points. Defaults to None.

    Returns:
        The resampled points and the total arc length of the curve. When `t` is given, the time
        values linearly interpolated at the same positions are returned as a third element.
    """
    # Cumulative arc length at every input point, starting from 0 at the first point.
    segment_lengths = np.linalg.norm(np.diff(X, axis=0), axis=1)
    cumulative = np.concatenate(([0], np.cumsum(segment_lengths)))
    total_length = cumulative[-1]

    targets = np.linspace(cumulative[0], total_length, num)
    resampled = interp1d(cumulative, X, axis=0)(targets)
    if t is None:
        return resampled, total_length

    resampled_t = interp1d(cumulative, t)(targets)
    return resampled, total_length, resampled_t
def remove_redundant_points_trajectory(
    X: np.ndarray,
    tol: float = 1e-4,
    output_discard: bool = False,
) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, np.ndarray]]:
    """Drop points that lie closer than `tol` to the point immediately before them.

    Each point is compared with its predecessor in the input, not with the last point that was kept.

    Args:
        X: The data points to remove redundant points from.
        tol: The tolerance for removing redundant points. Defaults to 1e-4.
        output_discard: Whether to also return the boolean mask of discarded points. Defaults to False.

    Returns:
        The remaining data points and the arc length of the curve through them. When `output_discard`
        is True, the discard mask over the input points is returned as a third element.
    """
    points = np.atleast_2d(X)
    n_total = len(points)
    discard = np.zeros(n_total, dtype=bool)
    if n_total > 1:
        for nxt in range(1, n_total):
            if np.linalg.norm(points[nxt] - points[nxt - 1]) < tol:
                discard[nxt] = True
        points = points[~discard]

    # Sum the segment lengths of the filtered curve.
    arclength = 0
    for previous, current in zip(points[:-1], points[1:]):
        arclength += np.linalg.norm(current - previous)

    if output_discard:
        return (points, arclength, discard)
    return (points, arclength)
def arclength_sampling(
    X: np.ndarray, step_length: float, n_steps: int, t: Optional[np.ndarray] = None
) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, list]]:
    """Uniformly sample data points on an arc curve that generated from vector field predictions.

    Walks along the polyline through `X` and emits a new point every time `step_length` has been
    covered since the previously emitted point.

    Args:
        X: The data points to sample from.
        step_length: The length of each step.
        n_steps: The number of steps to sample. Only used after the walk: if fewer than `n_steps`
            points were emitted, one more point is appended.
        t: The time values for the data points. Defaults to None.

    Returns:
        The sampled data points and the accumulated arc length (`step_length` times the number of
        iterations of the outer loop). When `t` is given, the list of times interpolated at the
        sampled points is returned as a third element.
    """
    Y = []  # sampled points
    x0 = X[0]  # most recently emitted point (initially the first input point)
    T = [] if t is not None else None  # sampled times, only tracked when `t` is given
    t0 = t[0] if t is not None else None  # time of the most recently emitted point
    i = 1  # index in X from which the next scan starts
    terminate = False
    arclength = 0

    def _calculate_new_point():
        # Reads `x0`, `t0`, `i`, `j`, `L`, `tangent` and `d` from the enclosing scope at call
        # time, so it must only be called after the scan below has assigned them.
        x = x0 if j == i else X[j - 1]
        # Advance along the current segment by the part of the step not yet covered.
        cur_y = x + (step_length - L) * tangent / d
        if t is not None:
            cur_tau = t0 if j == i else t[j - 1]
            # Same fraction of the segment, applied to the time axis.
            cur_tau += (step_length - L) / d * (t[j] - cur_tau)
            T.append(cur_tau)
        else:
            cur_tau = None
        Y.append(cur_y)
        return cur_y, cur_tau

    while i < len(X) - 1 and not terminate:
        L = 0  # length covered since the most recently emitted point
        for j in range(i, len(X)):
            # The first segment of a scan starts at the last emitted point rather than at X[j - 1].
            tangent = X[j] - x0 if j == i else X[j] - X[j - 1]
            d = np.linalg.norm(tangent)
            if L + d >= step_length:
                # The step ends inside this segment: emit a point and restart the scan from here.
                y, tau = _calculate_new_point()
                t0 = tau if t is not None else None
                x0 = y
                i = j
                break
            else:
                L += d
        # The scan reached the last input point; advance `i` so the outer loop can end.
        if j == len(X) - 1:
            i += 1
        arclength += step_length
        if L + d < step_length:
            terminate = True

    # NOTE(review): if the loop above never ran (len(X) <= 2), `j`, `L`, `tangent` and `d` are
    # unbound and this call raises NameError -- confirm callers always pass longer curves.
    if len(Y) < n_steps:
        # Append one more point, computed from the values left over from the last scan.
        _, _ = _calculate_new_point()

    if T is not None:
        return np.array(Y), arclength, T
    else:
        return np.array(Y), arclength