# Source code for cmrsim.analytic._composite_signal

""" This module contains the CompositeSignalModel, a container that stacks signal-process
building blocks (subclasses of BaseSignalModel) and calls them sequentially. """
__all__ = ["CompositeSignalModel"]

from typing import Union, Tuple, Optional

import tensorflow as tf
import numpy as np

from cmrsim.analytic.contrast import BaseSignalModel

class CompositeSignalModel(tf.Module):
    """ Class that sequentially calls the signal-process building blocks. This offers a general
    way to stack and exchange arbitrary building blocks without changing the simulation calling
    signature.

    Every building block is stored as an attribute named after its ``module.name``. Blocks are
    called in the order in which they were passed to the constructor.

    :param args: All signal modules that need to be concatenated. Arguments that are not
                 instances of ``BaseSignalModel`` are skipped.
    """
    # Union of the quantities required by all sub-modules (deduplicated, order preserving)
    required_quantities: Tuple[str, ...] = ('M0',)
    # Names of the sub-modules in calling order; used as keys into ``self.__dict__``
    _sub_module_order: Tuple[str, ...] = ()
    # Per-module factor by which the #repetitions axis grows (1 = no growth)
    _expansion_factors: Tuple[int, ...] = ()
    # Per-module name, parallel to ``_expansion_factors``
    _expansion_names: Tuple[str, ...] = ()
    # Per-module flag (``module.expand_repetitions``): is the expansion actually applied?
    _apply_expansion: Tuple[bool, ...] = ()

    def __init__(self, *args):
        """ Container module to concatenate and sequentially call subclasses of BaseSignalModel.
        Calling order is equal to the order of the arguments.

        :param args: All signal modules that need to be concatenated.
        :raises NotImplementedError: if a module's ``expansion_factor`` is None or smaller than 1.
        :raises ValueError: if two modules share the same ``name``. Modules are stored and called
                            by name, so a duplicate would silently replace the earlier module and
                            be called twice.
        """
        super().__init__(name="composite_signal_model")
        with self.name_scope:
            for module in args:
                # Anything that is not a signal building block is skipped
                if not isinstance(module, BaseSignalModel):
                    continue
                if module.expansion_factor is None or module.expansion_factor < 1:
                    raise NotImplementedError(f"The module {module.name} seems to not have set"
                                              f" the expansion factor for the #repetition-axis."
                                              f" If the module does not increase the number of "
                                              f"repetitions make sure it is set to 1! Currently"
                                              f" it is: {module.expansion_factor}")
                if module.name in self._sub_module_order:
                    raise ValueError(f"A module named '{module.name}' was already added to the "
                                     f"CompositeSignalModel. Sub-modules are stored and called by "
                                     f"name, so every module needs a unique name.")
                self.__dict__[module.name] = module
                self.required_quantities += module.required_quantities
                self._sub_module_order += (module.name,)
                self._expansion_factors += (module.expansion_factor,)
                self._apply_expansion += (module.expand_repetitions,)
                self._expansion_names += (module.name,)
        # Remove duplicate quantity names while keeping their first-seen order
        self.required_quantities = tuple(dict.fromkeys(self.required_quantities).keys())

    def __call__(self, signal_tensor: tf.Tensor,
                 segment_index: Optional[Union[int, tf.Tensor]] = 0, **kwargs):
        """ Consecutively calls all sub-modules, each of which applies its signal change to the
        tensor returned by the previous one.

        :param signal_tensor: tensor that is passed through the chain of sub-modules
        :param segment_index: forwarded to every sub-module as keyword argument ``segment_index``
        :param kwargs: dictionary of tensors containing all required quantities. Is forwarded to
                       every sub-module.
        :return: the tensor returned by the last sub-module (the input if there are none)
        """
        kwargs.update({'segment_index': segment_index})
        for sub_mod_name in self._sub_module_order:
            # Looked up by name so that a module replaced via attribute assignment is used
            sub_mod = self.__dict__[sub_mod_name]
            signal_tensor = sub_mod(signal_tensor, **kwargs)
        return signal_tensor

    def __str__(self):
        string = "CompositeSignalModel: \n\t"
        string += "\n\t".join([f"{self.__dict__[sbm]}: {sbm}" for sbm in self._sub_module_order])
        return string

    def update(self):
        """ Calls the update function of all sub-modules and re-records, per module, the
        expansion factor, the name and whether the expansion is applied. """
        self._expansion_factors = ()
        self._expansion_names = ()
        self._apply_expansion = ()
        for module_name in self._sub_module_order:
            module = self.__dict__[module_name]
            module.update()
            self._expansion_factors += (int(module.expansion_factor),)
            self._expansion_names += (module.name,)
            # Refresh the flag too; otherwise new factors are paired with stale flags
            self._apply_expansion += (module.expand_repetitions,)

    @property
    def expected_number_of_repetitions(self) -> tf.Tensor:
        """ Returns the total expansion factor (factor by which the #repetitions axis grows).
        Only modules whose expansion is applied contribute. Returns 1 if there are none. """
        self.update()
        applied_factors = [factor for factor, apply in
                           zip(self._expansion_factors, self._apply_expansion) if apply]
        # Explicit dtype so an empty list still forms a valid (int32) tensor whose product is 1
        return tf.reduce_prod(tf.constant(applied_factors, dtype=tf.int32))

    @property
    def unstacked_axis_names(self) -> Tuple[str, ...]:
        """ Returns the names of the modules with expansion factors > 1 whose expansion is applied.

        Names are returned in module (calling) order.
        NOTE(review): ``unstack_repetitions`` lays the corresponding axes out in *reversed* module
        order. Confirm which ordering is intended before relying on a positional match.
        """
        self.update()
        names = tuple(str(name) for (name, factor, apply) in zip(self._expansion_names,
                                                                 self._expansion_factors,
                                                                 self._apply_expansion)
                      if factor > 1 and apply)
        return names

    def unstack_repetitions(self, simulation_result: Union[tf.Tensor, 'np.ndarray']) -> tf.Tensor:
        """ Uses the dimension expansion information from the sub-modules to unstack the simulated
        tensor of k-space samples or images.

        The first axis (index=0) is assumed to hold the stacked repetitions. It is replaced by one
        axis per applied expansion with factor > 1, ordered from the last module to the first.

        :param simulation_result: (#repetitions, [noise], samples)
        :return: (*expansion_factors, [noise], samples)
        """
        self.update()
        input_shape = tf.shape(simulation_result)
        unstacked_repetitions_shape = tuple(
            factor for (factor, apply) in zip(self._expansion_factors[::-1],
                                              self._apply_expansion[::-1])
            if factor > 1 and apply)
        # Match the dtype of tf.shape so that the concatenation also works when nothing expands
        leading_shape = tf.constant(unstacked_repetitions_shape, dtype=input_shape.dtype)
        resulting_shape = tf.concat((leading_shape, input_shape[1:]), 0)
        return tf.reshape(simulation_result, resulting_shape)