""" This modules contains signal model based signal modification due to bulk-motion """
__all__ = ["PhaseTracking"]
import tensorflow as tf
from cmrsim.analytic.contrast.base import BaseSignalModel
# pylint: disable=abstract-method
[docs]
class PhaseTracking(BaseSignalModel):
# pylint: disable=anomalous-backslash-in-string
""" Evalulates Phase accrual for pre-defined particle trajectories
Phase calculation for a simple spatially constant gradient wave form and for
moving particles according to:
.. math::
\phi(t) = \int \\vec{G}(t) \\cdot \\vec{r}(t) dt
:param wave_form: (#repetition, #steps, [t, x, y, z]) - time and wave-form-grid
:param gamma: Gyromagnetic ratio in rad/ms/mT
"""
required_quantities = ('trajectory', )
#: (#repetitions, #time-steps, 4[t, x, y, z])
wave_form: tf.Variable
#: gyromagnetic ratio in rad/mT/ms
gamma: tf.Tensor
def __init__(self, expand_repetitions: bool, wave_form: tf.Tensor,
gamma: float, **kwargs):
# pylint: disable=anomalous-backslash-in-string
"""
:param wave_form: (#repetition, #steps, [t, x, y, z]) - time and wave-form-grid
:param gamma: Gyromagnetic ratio in rad/ms/mT
"""
super().__init__(expand_repetitions=expand_repetitions, name="phase_tracking", **kwargs)
with tf.device(self.device):
self.wave_form = tf.Variable(wave_form, shape=(None, None, 4),
dtype=tf.float32, name='wave_form')
self.n_steps = tf.Variable(0, shape=(), dtype=tf.int64, name="n_steps")
self.gamma = tf.constant(gamma, dtype=tf.float32)
self.expansion_factor = tf.shape(self.wave_form)[0]
self.update()
# @tf.function(jit_compile=True)
[docs]
def __call__(self, signal_tensor: tf.Tensor, trajectory: tf.Tensor, **kwargs) -> tf.Tensor:
""" Evaluates the phase integration for given trajectories
:param signal_tensor: (#voxel, #repetitions, #k-space-samples) last two dimensions
can be 1
:param trajectory: (#voxel, #repetitions, #k-space-samples, #time-steps, 3)
last dimension corresponds to (x, y, z) in meter
#time-steps must be the same as self.wave_form
:return: - if expand_repetitions == True: (#voxel, #repetitions \* self.wave_form.shape[0],
#k-space-samples)
- if expand_repetitions == False: (#voxel, #repetitions, #k-space-samples)
"""
with tf.device(self.device):
tf.Assert(tf.shape(trajectory)[-2] == tf.shape(self.wave_form.read_value())[1],
["[PhaseTracking] Wave-form and trajectory grid dont match!"])
input_shape = tf.shape(signal_tensor)
k_space_axis_tiling = tf.cast((tf.floor(input_shape[2]/tf.shape(trajectory)[2])),
tf.int32)
# trapezoidal integration
tf.Assert(tf.shape(trajectory)[1] == 1 or
tf.shape(trajectory)[1] == self.expansion_factor,
["Shape missmatch for input argument trajectory"])
delta_t = self.wave_form[:, 1:, 0] - self.wave_form[:, 0:-1, 0]
g_dot_r = tf.einsum('...ti, b...kti -> b...kt', self.wave_form[:, :, 1:], trajectory)
phases = tf.reduce_sum(delta_t[tf.newaxis, :, tf.newaxis]
* (g_dot_r[..., 1:] + g_dot_r[..., :-1]) / 2,
axis=-1) * self.gamma
complex_phase = tf.exp(tf.complex(0., phases))
complex_phase = tf.tile(complex_phase, [1, 1, k_space_axis_tiling])
# Case 1: expand-dimensions
if self.expand_repetitions or self.expansion_factor == 1:
temp = tf.einsum('vrk, vnk -> vrnk', signal_tensor, complex_phase)
result = tf.reshape(temp, (input_shape[0], -1, input_shape[2]))
# Case 2: repetition-axis of signal_tensor must match self.expansion_factor
else:
tf.Assert(input_shape[1] == self.expansion_factor,
["Shape missmatch for input argument signal_tensor for case no-expand",
input_shape, self.expansion_factor])
result = signal_tensor * complex_phase
return result
[docs]
def update(self):
super().update()
self.n_steps.assign(self.wave_form.read_value().shape[1])
self.expansion_factor = tf.shape(self.wave_form)[0]