__all__ = ["SliceProfile", "LocalLookREST"]
from typing import Union, Tuple, Sequence
import tensorflow as tf
import numpy as np
from cmrsim.analytic.contrast.base import BaseSignalModel
[docs]
class SliceProfile(BaseSignalModel):
"""Simplified slice-selection module, computing a weighting from 0 to 1 depending on through-slice position.
.. note::
To modify the actual slice-profile you can implement a subclass of this class that implements a different
'slice_profile' method.
.. dropdown:: Example Images
.. image:: _static/api/analytic/contrast_slice_profile.png
.. dropdown:: Example Usage
.. code::
uniform_mesh = pyvista.UniformGrid(spacing=(0.001, 0.001, 0.001), dimensions=(100, 100, 100),
origin=(-0.05, -0.05, -0.05))
r_vectors = tf.constant(uniform_mesh.points, dtype=tf.float32)
slice_normal = np.array([[0, 1, 1], [0, 0, 1], [1, 2, 0]], dtype=np.float64)
slice_normal /= np.linalg.norm(slice_normal, keepdims=True, axis=-1)
slice_position = Quantity([[0, 1, 1], [1, 0, 0], [0, 1, 0]], "cm").m_as("m")
slice_thickness = Quantity([2, 4, 0.5], "cm").m_as("m")
slice_mod = cmrsim.analytic.contrast.SliceProfile(expand_repetitions=True, slice_normal=slice_normal,
slice_position=slice_position,
slice_thickness=slice_thickness)
r_vectors_excitation = tf.reshape(r_vectors, [-1, 1, 1, 3])
signal_start = tf.ones(r_vectors_excitation.shape[:-1], dtype=tf.complex64)
signal_out = slice_mod(signal_tensor=signal_start, r_vectors_excitation=r_vectors_excitation).numpy()
for i in range(signal_out.shape[1]):
uniform_mesh[f"signal_out{i}"] = np.abs(signal_out)[:, i, 0]
:param expand_repetitions: if True, expands repetition axes with the factor determined by the first axes of the
following input arguments
:param slice_normal: (expansion_factor, 3) vector defining the slice normal of the excitation slice
:param slice_position: (expansion_factor, 3) vector determining the slice-center-point
:param slice_thickness:(expansion_factor, ) scalar in meter, determining the slice thickness
:param device:
"""
#: Additional mandatory keyword arguments for call
required_quantities = ('r_vectors_excitation', )
#: (expansion_factor, 3) vector defining the slice normal of the excitation slice
slice_normal: tf.Variable = None
#: (expansion_factor, 3) vector determining the slice-center-point
slice_position: tf.Variable = None
#: (expansion_factor, ) scalar in meter, determining the slice thickness
slice_thickness: tf.Variable = None
def __init__(self, expand_repetitions: bool, slice_normal: Sequence[Union[float, Sequence[float]]],
slice_position: Sequence[Union[float, Sequence[float]]],
slice_thickness: Union[float, Sequence[float]], device: str = None):
slice_normal = tf.reshape(tf.constant(slice_normal, dtype=tf.float32), (-1, 3))
slice_position = tf.reshape(tf.constant(slice_position, dtype=tf.float32), (-1, 3))
slice_thickness = tf.reshape(tf.constant(slice_thickness, dtype=tf.float32), (-1))
self._validate_shapes(slice_normal, slice_position, slice_thickness)
super(SliceProfile, self).__init__(name="slice_excitation", expand_repetitions=expand_repetitions,
device=device)
self.slice_normal = tf.Variable(slice_normal, shape=(None, 3), dtype=tf.float32)
self.slice_position = tf.Variable(slice_position, shape=(None, 3), dtype=tf.float32)
self.slice_thickness = tf.Variable(slice_thickness, shape=(None,), dtype=tf.float32)
self.update()
@staticmethod
def _validate_shapes(slice_normal, slice_position, slice_thickness):
"""Checks if the number of normals/positions/thickness is the same
:raises: Value error if len(...) of all input arguments are not equal
"""
if not ((len(slice_normal) == len(slice_position)) and (len(slice_normal) == len(slice_thickness))):
raise ValueError("Must specify the same number of slice normals/positions/thickness but got: "
f"{[len(s) for s in (slice_normal, slice_position, slice_thickness)]}")
[docs]
def update(self):
self._validate_shapes(self.slice_normal.read_value(), self.slice_position.read_value(),
self.slice_thickness.read_value())
self.expansion_factor.assign(tf.shape(self.slice_normal.read_value())[0])
[docs]
def __call__(self, signal_tensor: tf.Tensor, r_vectors_excitation: tf.Tensor, **kwargs): # noqa
""" Call function for analytice slice-profile weighting
:raises: AssertionError - r_vectors_excitation.shape[1] not equal to 1 or self.expansion_factor
- r_vectors_excitation.shape[2] is not equal to 1 (k-samples not supported here)
:param signal_tensor: (#batch, [#repetitions, 1], #ksamples)
:param r_vectors_excitation: (#batch, [#repetition, 1], 1, 3)
:return: signal_tensor weighted by slice selective excitation
"""
with tf.device(self.device):
# All Cases --> repetitions-axis of argument r_vectors must be either 1 or equal to self.expansion_factor
tf.Assert(tf.shape(r_vectors_excitation)[1] == 1 or
tf.shape(r_vectors_excitation)[1] == self.expansion_factor,
["Shape missmatch for input argument r_vectors_excitation in SliceProfile!"
" Repetions axis must match expansion factor or 1, but got", tf.shape(r_vectors_excitation),
self.expansion_factor])
tf.Assert(tf.shape(r_vectors_excitation)[2] == 1,
["Shape missmatch for input argument r_vectors_excitation in SliceProfile!"
" k-space sample axis must be equal to 1, but got: ", tf.shape(r_vectors_excitation)])
input_shape = tf.shape(signal_tensor)
relative_coords = r_vectors_excitation - tf.reshape(self.slice_position, [1, -1, 1, 3])
s_coords = tf.einsum("vrki, ri -> vrk", relative_coords, self.slice_normal)
profile_factors = self.slice_profile(s_coords)
profile_factors = tf.complex(profile_factors, tf.zeros_like(profile_factors))
# Case 1: expand-dimensions
if self.expand_repetitions or self.expansion_factor == 1:
temp = tf.einsum('vrk, vek -> vrek', signal_tensor, profile_factors)
result = tf.reshape(temp, (input_shape[0], -1, input_shape[2]))
else: # Case 2: repetition-axis of signal_tensor must match self.expansion_factor
tf.Assert(input_shape[1] == self.expansion_factor,
["Shape missmatch for input argument signal_tensor for case no-expand in SliceProfile! "
"Expected repetitions axis == self.expansion factor but got: ", input_shape, " | ",
self.expansion_factor])
result = tf.einsum('vrk, vrk -> vrk', signal_tensor, profile_factors)
return result
[docs]
def slice_profile(self, s_coord: tf.Tensor) -> tf.Tensor:
"""
:param s_coord: arbitrary shaped tensor containing the values for through-slice coordinate relative to slice
position
:return: factor between 0, 1 for each position (of stame shape as s_coord)
"""
return tf.where(tf.abs(s_coord) < tf.reshape(self.slice_thickness, [1, -1, 1]) / 2,
tf.ones_like(s_coord), tf.zeros_like(s_coord))
[docs]
class LocalLookREST(BaseSignalModel):
"""Simplified Local Look with REST slabs module, computing a weighting from 0 to 1 depending on
MPS position (Box-selective excitation).
.. note::
To modify the actual profile per M/P/S direction you can create a subclass of this class that
implements a different 'box_profile' method.
.. dropdown:: Example Images
.. image:: _static/api/analytic/contrast_box_profile.png
.. dropdown:: Example Usage
uniform_mesh = pyvista.UniformGrid(spacing=(0.001, 0.001, 0.001), dimensions=(100, 100, 100), origin=(-0.05, -0.05, -0.05))
r_vectors = tf.constant(uniform_mesh.points, dtype=tf.float32)
slice_normal = np.eye(3, 3)
readouts = np.roll(np.eye(3, 3), 1, axis=0)
phase_encodes = np.roll(np.eye(3, 3), 2, axis=0)
slice_position = Quantity([[0, 1, 1], [1, 0, 0], [0, 0, 1]], "cm").m_as("m")
spatial_extends = Quantity([[3, 2, 0.5], [5, 10, 1], [8, 6, 2]], "cm").m_as("m") * 1.5
rotation_matrices = tf.stack([readouts, phase_encodes, slice_normal], axis=1)
lolo_mod = cmrsim.analytic.contrast.LocalLookREST(expand_repetitions=True, slice_normal=slice_normal,
readout_direction=readouts, phase_encoding_direction=phase_encodes,
slice_position=slice_position, spatial_extends=spatial_extends)
r_vectors_excitation = tf.reshape(r_vectors, [-1, 1, 1, 3])
signal_start = tf.ones(r_vectors_excitation.shape[:-1], dtype=tf.complex64)
signal_out = lolo_mod(signal_tensor=signal_start, r_vectors_excitation=r_vectors_excitation).numpy()
for i in range(signal_out.shape[1]):
uniform_mesh[f"signal_out{i}"] = np.abs(signal_out)[:, i, 0]
:param expand_repetitions: if True, expands repetition axes with the factor determined by the first axes of the
following input arguments
:param slice_normal: (expansion_factor, 3) vector defining the slice normal of the excitation slice
:param readout_direction: (expansion_factor, 3) vector defining the readout direction
:param phase_encoding_direction: (expansion_factor, 3) vector defining the phase encoding direction
:param slice_position: (expansion_factor, 3) vector determining the slice-center-point
:param spatial_extends: (expansion_factor, 3) box-width per M-P-S direction around the slice-position
:param device:
"""
#: Additional mandatory keyword arguments for call
required_quantities = ('r_vectors_excitation', )
#: (expansion_factor, 4, 4) orientation matrix containing the transformation matrix from r_vectors into MPS
#: coordinates. Is composed from slice_normal, phase_encoding and readout directions. matrix per repetition
#: is guaranteed to be orthogonal
orientation_matrix: tf.Variable
#: (expansion_factor, 3) spatial extend per M-P-S directions (conceptually equivalent to slice-thickness)
spatial_extend: tf.Variable
def __init__(self, expand_repetitions: bool,
slice_normal: Sequence[Union[float, Sequence[float]]],
readout_direction: Sequence[Union[float, Sequence[float]]],
phase_encoding_direction: Sequence[Union[float, Sequence[float]]],
slice_position: Sequence[Union[float, Sequence[float]]],
spatial_extends: Sequence[Union[float, Sequence[float]]],
device: str = None):
slice_normal = tf.reshape(tf.constant(slice_normal, dtype=tf.float32), (-1, 3))
readout_direction = tf.reshape(tf.constant(readout_direction, dtype=tf.float32), (-1, 3))
phase_encoding_direction = tf.reshape(tf.constant(phase_encoding_direction, dtype=tf.float32), (-1, 3))
slice_position = tf.reshape(tf.constant(slice_position, dtype=tf.float32), (-1, 3))
spatial_extends = tf.reshape(tf.constant(spatial_extends, dtype=tf.float32), (-1, 3))
self._validate_shapes(slice_normal, readout_direction, phase_encoding_direction,
slice_position, spatial_extends)
super(LocalLookREST, self).__init__(name="lolo_rest", expand_repetitions=expand_repetitions,
device=device)
transformation_matrices = np.zeros([len(slice_normal), 4, 4], dtype=np.float32)
transformation_matrices[:, :3, :3] = np.stack([readout_direction, phase_encoding_direction,
slice_normal], axis=1)
transformation_matrices[:, :3, 3] = slice_position
transformation_matrices[:, 3, 3] = 1.
self.orientation_matrix = tf.Variable(transformation_matrices, shape=(None, 4, 4), dtype=tf.float32)
self.slice_position = tf.Variable(slice_position, shape=(None, 3), dtype=tf.float32)
self.spatial_extend = tf.Variable(spatial_extends, shape=(None, 3), dtype=tf.float32)
self.update()
@staticmethod
def _validate_shapes(slice_normal, readout_direction, phase_encoding_direction, slice_position, spatial_extends):
"""Checks if the number of normals/positions/thickness is the same
:raises: ValueError - if len(...) of all input arguments are not equal
- if MPS directions are not orthonormal per repetition
"""
expansion_factors = [len(arg) for arg in (slice_normal, readout_direction, phase_encoding_direction,
slice_position, spatial_extends)]
if not all([e == expansion_factors[0] for e in expansion_factors]):
raise ValueError("Must specify the same number of slice normals/readouts/phase encode/positions/thickness"
f" but got: {expansion_factors}")
rotation_matrices = tf.stack([readout_direction, phase_encoding_direction, slice_normal], axis=1)
orthonormality = tf.einsum("nij, njk -> nik", rotation_matrices, tf.transpose(rotation_matrices, [0, 2, 1]))
orthonormality_bool = [np.allclose(mat, np.eye(3, 3), atol=1e-3) for mat in orthonormality]
if not all(orthonormality_bool):
raise ValueError("All encoding directions per repetitions must form an orthonormal basis: "
f"{orthonormality_bool} \n {orthonormality}")
[docs]
def update(self):
self._validate_shapes(self.orientation_matrix.read_value()[:, 0],
self.orientation_matrix.read_value()[:, 1],
self.orientation_matrix.read_value()[:, 2],
self.slice_position.read_value(),
self.spatial_extend.read_value())
self.expansion_factor.assign(tf.shape(self.orientation_matrix.read_value())[0])
[docs]
def __call__(self, signal_tensor: tf.Tensor, r_vectors_excitation: tf.Tensor, **kwargs): # noqa
""" Call function for analytice local look with REST slabs weighting
:raises: AssertionError - r_vectors_excitation.shape[1] not equal to 1 or self.expansion_factor
- r_vectors_excitation.shape[2] is not equal to 1 (k-samples not supported here)
:param signal_tensor: (#batch, [#repetitions, 1], #ksamples)
:param r_vectors_excitation: (#batch, [#repetition, 1], 1, 3)
:return: signal_tensor weighted by box-selective excitation
"""
with tf.device(self.device):
# All Cases --> repetitions-axis of argument r_vectors must be either 1 or equal to self.expansion_factor
tf.Assert(tf.shape(r_vectors_excitation)[1] == 1 or
tf.shape(r_vectors_excitation)[1] == self.expansion_factor,
["Shape missmatch for input argument r_vectors_excitation in SliceProfile!"
" Repetions axis must match expansion factor or 1, but got", tf.shape(r_vectors_excitation),
self.expansion_factor])
tf.Assert(tf.shape(r_vectors_excitation)[2] == 1,
["Shape missmatch for input argument r_vectors_excitation in SliceProfile!"
" k-space sample axis must be equal to 1, but got: ", tf.shape(r_vectors_excitation)])
input_shape = tf.shape(signal_tensor)
relative_coords = r_vectors_excitation - tf.reshape(self.slice_position, [1, -1, 1, 3])
augmented_coords = tf.concat([relative_coords, tf.ones_like(relative_coords[..., :1])], axis=-1)
mps_coords = tf.einsum("rij, vrki-> vrkj", self.orientation_matrix, augmented_coords)
profile_factors = self.box_profile(mps_coords)
profile_factors = tf.complex(profile_factors, tf.zeros_like(profile_factors))
# Case 1: expand-dimensions
if self.expand_repetitions or self.expansion_factor == 1:
tf.print(tf.shape(signal_tensor), tf.shape(profile_factors))
temp = tf.einsum('vrk, vek -> vrek', signal_tensor, profile_factors)
result = tf.reshape(temp, (input_shape[0], -1, input_shape[2]))
else: # Case 2: repetition-axis of signal_tensor must match self.expansion_factor
tf.Assert(input_shape[1] == self.expansion_factor,
["Shape missmatch for input argument signal_tensor for case no-expand in SliceProfile! "
"Expected repetitions axis == self.expansion factor but got: ", input_shape, " | ",
self.expansion_factor])
result = tf.einsum('vrk, vrk -> vrk', signal_tensor, profile_factors)
return result
[docs]
def box_profile(self, mps_coords: tf.Tensor) -> tf.Tensor:
"""
:param mps_coords: arbitrary shaped tensor containing the values for MPS coordinate relative to slice position
:return: factor between 0, 1 for each position (of same shape as mps_coords[..., 0])
"""
in_x = tf.abs(mps_coords[..., 0]) < tf.reshape(self.spatial_extend[..., 0], [1, -1, 1]) / 2
in_y = tf.abs(mps_coords[..., 1]) < tf.reshape(self.spatial_extend[..., 1], [1, -1, 1]) / 2
in_z = tf.abs(mps_coords[..., 2]) < tf.reshape(self.spatial_extend[..., 2], [1, -1, 1]) / 2
binary_factors = tf.where(tf.reduce_prod(tf.cast(tf.stack([in_x, in_y, in_z], axis=0), tf.float32), axis=0)> 0.,
tf.ones_like(mps_coords[..., 0]), tf.zeros_like(mps_coords[..., 0]))
return binary_factors