Source code for dynamo.prediction.trajectory

from typing import Callable, List, Tuple, Union

import numpy as np
from scipy.interpolate import interp1d

from ..dynamo_logger import LoggerManager
from ..tools.utils import flatten
from ..utils import expr_to_pca, pca_to_expr
from ..vectorfield.scVectorField import DifferentiableVectorField
from ..vectorfield.utils import angle, normalize_vectors
from .utils import arclength_sampling_n


[docs]class Trajectory: def __init__(self, X: np.ndarray, t: Union[None, np.ndarray] = None, sort: bool = True) -> None: """ Base class for handling trajectory interpolation, resampling, etc. Args: X: trajectory positions, shape (n_points, n_dimensions) t: trajectory times, shape (n_points,). Defaults to None. sort: whether to sort the time stamps. Defaults to True. """ self.X = X if t is None: self.t = None else: self.set_time(t, sort=sort) def __len__(self) -> int: """Returns the number of points in the trajectory. Returns: number of points in the trajectory """ return self.X.shape[0]
[docs] def set_time(self, t: np.ndarray, sort: bool = True) -> None: """ Set the time stamps for the trajectory. Sorts the time stamps if requested. Args: t: trajectory times, shape (n_points,) sort: whether to sort the time stamps. Defaults to True. """ if sort: I = np.argsort(t) self.t = t[I] self.X = self.X[I] else: self.t = t
[docs] def dim(self) -> int: """ Returns the number of dimensions in the trajectory. Returns: number of dimensions in the trajectory """ return self.X.shape[1]
[docs] def calc_tangent(self, normalize: bool = True): """ Calculate the tangent vectors of the trajectory. Args: normalize: whether to normalize the tangent vectors. Defaults to True. Returns: tangent vectors of the trajectory, shape (n_points-1, n_dimensions) """ tvec = self.X[1:] - self.X[:-1] if normalize: tvec = normalize_vectors(tvec) return tvec
[docs] def calc_arclength(self) -> float: """ Calculate the arc length of the trajectory. Returns: arc length of the trajectory """ tvec = self.calc_tangent(normalize=False) norms = np.linalg.norm(tvec, axis=1) return np.sum(norms)
[docs] def calc_curvature(self) -> np.ndarray: """ Calculate the curvature of the trajectory. Returns: curvature of the trajectory, shape (n_points,) """ tvec = self.calc_tangent(normalize=False) kappa = np.zeros(self.X.shape[0]) for i in range(1, self.X.shape[0] - 1): # ref: http://www.cs.jhu.edu/~misha/Fall09/1-curves.pdf (p. 55) kappa[i] = angle(tvec[i - 1], tvec[i]) / (np.linalg.norm(tvec[i - 1] / 2) + np.linalg.norm(tvec[i] / 2)) return kappa
[docs] def resample(self, n_points: int, tol: float = 1e-4, inplace: bool = True) -> Tuple[np.ndarray, np.ndarray]: """ Resample the curve with the specified number of points. Args: n_points: An integer specifying the number of points in the resampled curve. tol: A float specifying the tolerance for removing redundant points. Default is 1e-4. inplace: A boolean flag indicating whether to modify the curve object in place. Default is True. Returns: A tuple containing the resampled curve coordinates and time values (if available). Raises: ValueError: If the specified number of points is less than 2. TODO: Decide whether the tol argument should be included or not during the code refactoring and optimization. """ # remove redundant points """if tol is not None: X, arclen, discard = remove_redundant_points_trajectory(self.X, tol=tol, output_discard=True) if self.t is not None: t = np.array(self.t[~discard], copy=True) else: t = None else: X = np.array(self.X, copy=True) t = np.array(self.t, copy=True) if self.t is not None else None arclen = self.calc_arclength()""" # resample using the arclength sampling # ret = arclength_sampling(X, arclen / n_points, t=t) ret = arclength_sampling_n(self.X, n_points, t=self.t) X = ret[0] if self.t is not None: t = ret[2] if inplace: self.X, self.t = X, t return X, t
[docs] def interpolate(self, t: np.ndarray, **interp_kwargs) -> np.ndarray: """Interpolate the curve at new time values. Args: t: The new time values at which to interpolate the curve. **interp_kwargs: Additional arguments to pass to `scipy.interpolate.interp1d`. Returns: The interpolated values of the curve at the specified time values. Raises: Exception: If `self.t` is `None`, which is needed for interpolation. """ if self.t is None: raise Exception("`self.t` is `None`, which is needed for interpolation.") return interp1d(self.t, self.X, axis=0, **interp_kwargs)(t)
[docs] def interp_t(self, num: int = 100) -> np.ndarray: """Interpolates the `t` parameter linearly. Args: num: Number of interpolation points. Returns: The array of interpolated `t` values. """ if self.t is None: raise Exception("`self.t` is `None`, which is needed for interpolation.") return np.linspace(self.t[0], self.t[-1], num=num)
[docs] def interp_X(self, num: int = 100, **interp_kwargs) -> np.ndarray: """ Interpolates the curve at `num` equally spaced points in `t`. Args: num: The number of points to interpolate the curve at. **interp_kwargs: Additional keyword arguments to pass to `scipy.interpolate.interp1d`. Returns: The interpolated curve at `num` equally spaced points in `t`. """ if self.t is None: raise Exception("`self.t` is `None`, which is needed for interpolation.") return self.interpolate(self.interp_t(num=num), **interp_kwargs)
[docs] def integrate(self, func: Callable) -> np.ndarray: """Calculate the integral of a function along the curve. Args: func: A function to integrate along the curve. Returns: The integral of func along the discrete curve. """ F = np.zeros(func(self.X[0]).shape) tvec = self.calc_tangent(normalize=False) for i in range(1, self.X.shape[0] - 1): # ref: http://www.cs.jhu.edu/~misha/Fall09/1-curves.pdf P. 47 F += func(self.X[i]) * (np.linalg.norm(tvec[i - 1]) + np.linalg.norm(tvec[i])) / 2 return F
[docs] def calc_msd(self, decomp_dim: bool = True, ref: int = 0) -> Union[float, np.ndarray]: """Calculate the mean squared displacement (MSD) of the curve with respect to a reference point. Args: decomp_dim: If True, return the MSD of each dimension separately. If False, return the total MSD. ref: Index of the reference point. Default is 0. Returns: The MSD of the curve with respect to the reference point. """ S = (self.X - self.X[ref]) ** 2 if decomp_dim: S = S.sum(axis=0) else: S = S.sum() S /= len(self) return S
class VectorFieldTrajectory(Trajectory): def __init__(self, X: np.ndarray, t: np.ndarray, vecfld: DifferentiableVectorField) -> None: """ Initializes a VectorFieldTrajectory object. Args: X: The trajectory data as a numpy array of shape (n, d). t: The time data as a numpy array of shape (n,). vecfld: The differentiable vector field that describes the trajectory. """ super().__init__(X, t=t) self.vecfld = vecfld self.data = {"velocity": None, "acceleration": None, "curvature": None, "divergence": None} self.Js = None def get_velocities(self) -> np.ndarray: """ Calculates and returns the velocities along the trajectory. Returns: The velocity data as a numpy array of shape (n, d). """ if self.data["velocity"] is None: self.data["velocity"] = self.vecfld.func(self.X) return self.data["velocity"] def get_jacobians(self, method=None) -> np.ndarray: """ Calculates and returns the Jacobians of the vector field along the trajectory. Args: method: The method used to compute the Jacobians. Returns: The Jacobian data as a numpy array of shape (n, d, d). """ if self.Js is None: fjac = self.vecfld.get_Jacobian(method=method) self.Js = fjac(self.X) return self.Js def get_accelerations(self, method=None, **kwargs) -> np.ndarray: """ Calculates and returns the accelerations along the trajectory. Args: method: The method used to compute the Jacobians. **kwargs: Additional keyword arguments to be passed to the acceleration computation method. Returns: The acceleration data as a numpy array of shape (n, d). """ if self.data["acceleration"] is None: if self.Js is None: self.Js = self.get_jacobians(method=method) self.data["acceleration"] = self.vecfld.compute_acceleration(self.X, Js=self.Js, **kwargs) return self.data["acceleration"] def get_curvatures(self, method=None, **kwargs) -> np.ndarray: """ Calculates and returns the curvatures along the trajectory. Args: method: The method used to compute the Jacobians. **kwargs: Additional keyword arguments to be passed to the curvature computation method. Returns: The curvature data as a numpy array of shape (n,). """ if self.data["curvature"] is None: if self.Js is None: self.Js = self.get_jacobians(method=method) self.data["curvature"] = self.vecfld.compute_curvature(self.X, Js=self.Js, **kwargs) return self.data["curvature"] def get_divergences(self, method=None, **kwargs) -> np.ndarray: """ Calculates and returns the divergences along the trajectory. Args: method: The method used to compute the Jacobians. **kwargs: Additional keyword arguments to be passed to the divergence computation method. Returns: The divergence data as a numpy array of shape (n,). """ if self.data["divergence"] is None: if self.Js is None: self.Js = self.get_jacobians(method=method) self.data["divergence"] = self.vecfld.compute_divergence(self.X, Js=self.Js, **kwargs) return self.data["divergence"] def calc_vector_msd(self, key: str, decomp_dim: bool = True, ref: int = 0) -> Union[np.ndarray, float]: """Calculate and return the mean squared displacement of a given vector field attribute in the trajectory. Args: key: The key for the vector field attribute in self.data to compute the mean squared displacement of. decomp_dim: Whether to decompose the MSD by dimension. Defaults to True. ref: The index of the reference point to use for computing the MSD. Defaults to 0. Returns: The mean squared displacement of the specified vector component in the trajectory. TODO: Discuss should we also calculate other quantities during the code refactoring and optimization phase (e.g. curl, hessian, laplacian, etc). """ V = self.data[key] S = (V - V[ref]) ** 2 if decomp_dim: S = S.sum(axis=0) else: S = S.sum() S /= len(self) return S
[docs]class GeneTrajectory(Trajectory): def __init__( self, adata, X=None, t=None, X_pca=None, PCs="PCs", mean="pca_mean", genes="use_for_pca", expr_func=None, **kwargs, ) -> None: """ This class is not fully functional yet. """ self.adata = adata if type(PCs) is str: PCs = self.adata.uns[PCs] self.PCs = PCs if type(mean) is str: mean = self.adata.uns[mean] self.mean = mean self.expr_func = expr_func if type(genes) is str: genes = adata.var_names[adata.var[genes]].to_list() self.genes = np.array(genes) if X_pca is not None: self.from_pca(X_pca, t=t, **kwargs) if X is not None: super().__init__(X, t=t) def from_pca(self, X_pca, t=None): X = pca_to_expr(X_pca, self.PCs, mean=self.mean, func=self.expr_func) super().__init__(X, t=t) def to_pca(self, x=None): if x is None: x = self.X return expr_to_pca(x, self.PCs, mean=self.mean, func=self.expr_func) def genes_to_mask(self): mask = np.zeros(self.adata.n_vars, dtype=np.bool_) for g in self.genes: mask[self.adata.var_names == g] = True return mask
[docs] def calc_msd(self, save_key="traj_msd", **kwargs): msd = super().calc_msd(**kwargs) LoggerManager.main_logger.info_insert_adata(save_key, "var") self.adata.var[save_key] = np.ones(self.adata.n_vars) * np.nan self.adata.var[save_key][self.genes_to_mask()] = msd return msd
def save(self, save_key="gene_trajectory"): LoggerManager.main_logger.info_insert_adata(save_key, "varm") self.adata.varm[save_key] = np.ones((self.adata.n_vars, self.X.shape[0])) * np.nan self.adata.varm[save_key][self.genes_to_mask(), :] = self.X.T def select_gene(self, genes, arr=None, axis=None): if arr is None: arr = self.X if arr.ndim == 1: axis = 0 else: if axis is None: axis = 1 y = [] if self.genes is not None: for g in genes: if g not in self.genes: LoggerManager.main_logger.warning(f"{g} is not in `self.genes`.") else: if axis == 0: y.append(flatten(arr[self.genes == g])) elif axis == 1: y.append(flatten(arr[:, self.genes == g])) else: raise Exception("Cannot select genes since `self.genes` is `None`.") return np.array(y)