# Source code for cmrsim.trajectory._base

"""Base definitions for trajectory modules.

Provides the abstract ``BaseTrajectoryModule`` interface and the trivial
``StaticParticlesTrajectory`` implementation for particles that do not move.
"""

__all__ = [
    "BaseTrajectoryModule",
    "StaticParticlesTrajectory",
]

from abc import abstractmethod

import tensorflow as tf


class BaseTrajectoryModule(tf.Module):
    """Abstract interface shared by all trajectory modules.

    A trajectory module describes how particle positions evolve over time and is
    meant to be used together with the Bloch simulation module. It declares two
    entry points:

    * ``increment_particles`` advances positions by a single time step. Concrete
      implementations must be traceable by ``tf.function``.
    * ``__call__`` pre-computes positions for a whole set of time points. It is not
      intended to be invoked inside a ``tf.function``.

    Derived classes are expected to implement both methods.

    .. note::
        NOTE(review): this class neither derives from ``abc.ABC`` nor sets an
        ``ABCMeta`` metaclass. Unless ``tf.Module`` supplies one, the
        ``@abstractmethod`` markers below are not enforced when a subclass is
        instantiated. Confirm before relying on that enforcement.
    """

    @abstractmethod
    def __call__(
        self, initial_positions: tf.Tensor, timing: tf.Tensor, **kwargs
    ) -> (tf.Tensor, dict):
        """Compute particle positions for every time point listed in ``timing``.

        Implementations that build their result by looping over
        ``increment_particles`` must also be given the maximal allowed
        time-delta, as an implementation-specific argument.

        NOTE(review): ``StaticParticlesTrajectory.__call__`` in this module tiles
        its output to (N, T, 3), not the (T, N, 3) stated below. Confirm which axis
        order the callers expect.

        :param initial_positions: (N, 3) starting positions of the particles
        :param timing: (T, ) time points at which positions are requested
        :param kwargs: options that differ between concrete implementations
        :return: a pair of

            - r_new [tf.Tensor, (T, N, 3)] positions at each requested time
            - additional_fields [dict] look-up values collected for each step
        """

    @abstractmethod
    def increment_particles(
        self, particle_positions: tf.Tensor, dt: tf.Tensor, **kwargs
    ) -> (tf.Tensor, dict):
        """Advance the particles at ``particle_positions`` by one step of width ``dt``.

        If the concrete implementation samples a field (e.g. a velocity field),
        the values found at the *old* positions are returned in a dictionary
        next to the new positions.

        .. note:: concrete implementations must be compatible with
           ``tf.function`` decoration

        :param particle_positions: (N, 3) current particle positions
        :param dt: (,) temporal step width in milliseconds used to compute the
            next positions
        :param kwargs: options that differ between concrete implementations
        :return: r_new [tf.Tensor, (N, 3)], additional_fields [dict] containing
            the look-up values
        """
# pylint: disable=abstract-method
class StaticParticlesTrajectory(BaseTrajectoryModule):
    """Trivial trajectory module for particles that never move.

    It satisfies the trajectory-module interface expected by the Bloch
    simulation. Both ``__call__`` and ``increment_particles`` leave the particle
    positions unchanged and return an empty dictionary, because no field
    look-ups are performed.
    """

    def __init__(self, name: str = "static_trajectory"):
        """Create the module.

        :param name: module name forwarded to ``tf.Module``. The default keeps the
            name the original implementation intended to set.
        """
        # Previously this was misspelled ``__int__``. The intended name was then
        # never applied, and ``int(instance)`` raised a confusing TypeError.
        super().__init__(name=name)

    def __call__(
        self, initial_positions: tf.Tensor, timing: tf.Tensor, **kwargs
    ) -> (tf.Tensor, dict):
        """Repeat the static initial positions once for every time point.

        :param initial_positions: (N, 3) particle positions
        :param timing: (T, ) time points; only the length of its first axis is used
        :param kwargs: ignored, accepted for interface compatibility
        :return: a pair of

            - r_const [tf.Tensor, (N, T, 3)]: ``initial_positions`` repeated T
              times along axis 1
            - an empty dict (no field look-ups)
        """
        n_steps = tf.shape(timing)[0]
        # (N, 3) -> (N, 1, 3) -> tiled along axis 1 -> (N, T, 3).
        # NOTE(review): BaseTrajectoryModule documents (T, N, 3) for __call__.
        # The layout is kept unchanged here because callers may depend on it.
        # Confirm which one is intended.
        return tf.tile(initial_positions[:, tf.newaxis], [1, n_steps, 1]), {}

    def increment_particles(
        self, particle_positions: tf.Tensor, dt: tf.Tensor, **kwargs
    ) -> (tf.Tensor, dict):
        """Return the positions unchanged; static particles do not move.

        :param particle_positions: (N, 3) current particle positions
        :param dt: temporal step width (unused)
        :param kwargs: ignored, accepted for interface compatibility
        :return: the given ``particle_positions`` object itself and an empty dict
        """
        return particle_positions, {}