# Source code for cmrsim.trajectory._taylor

"""Trajectory module that fits a Taylor polynomial to particle trajectories and evaluates
the fitted polynomial at requested times."""
__all__ = ["TaylorTrajectoryN"]

import numpy as np
import tensorflow as tf

from cmrsim.trajectory._base import BaseTrajectoryModule


# pylint: disable=abstract-method
class TaylorTrajectoryN(BaseTrajectoryModule):
    """Fits a Taylor polynomial of specified order to the given 3D particle trajectories and
    stores the resulting coefficients per particle. When called, evaluates the Taylor expansion
    at the given timing in a tf.function compatible definition.

    Incrementing particle positions is done by keeping track of the current timing.

    Batching the particles for all evaluations is done by setting the attributes
    self.batch_size and self.current_batch_idx. This results in the indexing:

        [self.batch_size*self.current_batch_idx : self.batch_size*(self.current_batch_idx+1)]

    .. dropdown:: Example Usage

        .. code-block:: python
            :caption: Instantiation

            ref_timing = ...  # shape (T, )
            ref_trajectory = ...  # shape (N, T, dims)
            module = TaylorTrajectoryN(order=3, time_grid=ref_timing,
                                       particle_trajectories=ref_trajectory)

    :param order: Order of the fitted Taylor polynomial
    :param time_grid: (#timesteps, )
    :param particle_trajectories: (#particles, #timesteps, 3)
    :param batch_size: used for evaluating the particle trajectories in batches
    :param fit_on_init: If True, the polynomial is fitted on instantiation of the module.
    :param is_periodic: If True, evaluation times are wrapped periodically into the window
        [time_grid[0], time_grid[-1])
    """
    #: Keeps track of the current timing when increment_particles is called.
    current_time_ms: tf.Variable
    #: Allows to only evaluate the position for a batch of stored particle trajectories
    current_batch_idx: tf.Variable
    #: Together with self.current_batch_idx determines the subset of particle trajectories that
    #: is evaluated on call and increment_particles
    batch_size: tf.Variable
    #: Stores the order of the Taylor polynomial, defined on instantiation
    order: tf.Variable
    #: Stores the result of fitting the Taylor polynomial for all particle trajectories
    #: with shape (#particles, order + 1, dims), lowest power first
    optimal_parameters: tf.Variable
    #: If True, evaluation times are wrapped with period (end_time - ref_time)
    is_periodic: tf.constant
    #: First entry of the time grid; the polynomial is expanded around this time
    ref_time: tf.Variable
    #: Last entry of the time grid; together with ref_time defines the period length
    end_time: tf.Variable

    # pylint: disable=too-many-arguments
    def __init__(self, order: int, time_grid: np.ndarray, particle_trajectories: np.ndarray,
                 batch_size: int = None, fit_on_init: bool = True, is_periodic: bool = False):
        """
        :param order: Order of the fitted Taylor polynomial
        :param time_grid: (#timesteps, )
        :param particle_trajectories: (#particles, #timesteps, 3)
        :param batch_size: used for evaluating the particle trajectories in batches. Defaults
            to the number of particles (i.e. a single batch).
        :param fit_on_init: If True, the polynomial is fitted on instantiation of the module.
        :param is_periodic: If True, evaluation times are wrapped periodically into the window
            [time_grid[0], time_grid[-1])
        """
        if batch_size is not None:
            self.batch_size = tf.Variable(batch_size, dtype=tf.int32, shape=(),
                                          trainable=False)
        else:
            self.batch_size = tf.Variable(particle_trajectories.shape[0], dtype=tf.int32,
                                          shape=(), trainable=False)
        self.ref_time = tf.Variable(time_grid[0].astype(np.float32), dtype=tf.float32,
                                    shape=(), trainable=False)
        self.end_time = tf.Variable(time_grid[-1].astype(np.float32), dtype=tf.float32,
                                    shape=(), trainable=False)
        self.current_batch_idx = tf.Variable(0, dtype=tf.int32, shape=(), trainable=False)
        self.current_time_ms = tf.Variable(self.ref_time, dtype=tf.float32, shape=(),
                                           trainable=False)
        self.order = tf.Variable(order, dtype=tf.int32, shape=(), trainable=False)
        # Python int copy of the order, needed by numpy's polyfit / reshape in fit()
        self._int_order = order
        # shape=(None, None, None) allows re-fitting with a different number of particles
        self.optimal_parameters = tf.Variable(
            tf.zeros((particle_trajectories.shape[0], order + 1, 3), dtype=tf.float32),
            dtype=tf.float32, shape=(None, None, None), trainable=False)
        super().__init__(name=f"taylor_trajectory_order{order}")
        if fit_on_init:
            # Times relative to ref_time, computed in numpy so that fit() receives the
            # np.ndarray it declares (subtracting the tf.Variable would yield a tf.Tensor)
            t_zero_ref = np.asarray(time_grid) - self.ref_time.numpy()
            self.fit(t_zero_ref, particle_trajectories)
        self.is_periodic = tf.constant(is_periodic, dtype=tf.bool)

    def fit(self, t_grid: np.ndarray, particle_trajectories: np.ndarray):
        """Fits a polynomial of order self.order to each particle trajectory and assigns the
        coefficients to self.optimal_parameters.

        The stored coefficients have shape (N, order + 1, dim) and are ordered from the lowest
        to the highest power of t.

        :param t_grid: (T, ) times corresponding to the particle positions, relative to
            self.ref_time (evaluation subtracts ref_time before applying the polynomial)
        :param particle_trajectories: (N, T, dim)
        :return: None
        """
        n_particles, n_steps, n_dims = particle_trajectories.shape
        # polyfit fits every column of a 2D `y` at once -> put (particle, dim) pairs in columns
        flattened_trajectories = np.swapaxes(particle_trajectories, 0, 1).reshape(n_steps, -1)
        flat_coefficients = np.polynomial.polynomial.polyfit(t_grid, flattened_trajectories,
                                                             deg=self._int_order)
        # polyfit returns (deg + 1, n_columns); restore (N, deg + 1, dim)
        coefficients = np.swapaxes(
            flat_coefficients.reshape(self._int_order + 1, n_particles, n_dims), 0, 1)
        self.optimal_parameters.assign(coefficients.astype(np.float32))

    @tf.function
    def _evaluate_trajectory(self, t: tf.Tensor) -> tf.Tensor:
        """Evaluates the Taylor expansion for the current batch of particles at times t.

        :param t: (#timesteps) absolute times, same time axis as the fitted time_grid
        :return: (#particles in batch, #timesteps, 3)
        """
        t = t - self.ref_time
        if self.is_periodic:
            # Wrap into [0, end_time - ref_time) so the fitted window repeats
            t = tf.math.floormod(t, self.end_time - self.ref_time)
        batch_start = self.current_batch_idx * self.batch_size
        batch_end = batch_start + self.batch_size
        factors = self.optimal_parameters[batch_start:batch_end]  # (batch, order+1, dims)
        exponents = tf.range(0, tf.cast(self.order + 1, dtype=tf.float32))[:, tf.newaxis]
        t_pow_n = t[np.newaxis] ** exponents  # (order+1, time)
        # (batch, order+1, 1, dims) * (1, order+1, time, 1) summed over powers
        out = tf.reduce_sum(factors[:, :, tf.newaxis] * t_pow_n[tf.newaxis, :, :, tf.newaxis],
                            axis=1)
        return out

    def __call__(self, initial_positions: tf.Tensor, timing: tf.Tensor, batch_index: int = 0,
                 **kwargs) -> (tf.Tensor, dict):
        """Evaluates the Taylor expansion for the particle batch `batch_index` at `timing`.

        self.current_batch_idx is temporarily set to batch_index and restored afterwards,
        also when the evaluation raises.

        :param initial_positions: unused parameter (to adhere to calling signature of
            trajectory modules)
        :param timing: (#timesteps) in milliseconds
        :param batch_index: index of the particle batch to evaluate
        :param kwargs: unused parameter (to adhere to calling signature of trajectory modules)
        :return: (#particles, #timesteps, 3) in meter, and an empty dict
        """
        idx_before = self.current_batch_idx.read_value()
        self.current_batch_idx.assign(batch_index)
        try:
            pos = self._evaluate_trajectory(timing)
        finally:
            self.current_batch_idx.assign(idx_before)
        return pos, {}

    @tf.function
    def increment_particles(self, particle_positions: tf.Tensor, dt: tf.Tensor,
                            **kwargs) -> (tf.Tensor, dict):
        """Adds dt to self.current_time_ms and evaluates the particle positions of the current
        batch at the new time.

        :param particle_positions: unused parameter (to adhere to calling signature of
            trajectory modules)
        :param dt: scalar temporal step length (current_time_ms is a scalar variable)
        :param kwargs: unused parameter (to adhere to calling signature of trajectory modules)
        :return: (#batch, 3) positions, and an empty dict
        """
        self.current_time_ms.assign_add(dt)
        new_pos = self._evaluate_trajectory(tf.reshape(self.current_time_ms, [1, ]))
        return new_pos[:, 0], {}