""" This module contains the Entry-point for defining and running simulations """
__all__ = ['AnalyticSimulation', ]
from typing import Tuple, Optional, TYPE_CHECKING
from abc import abstractmethod
import os
from collections import OrderedDict
import tensorflow as tf
if TYPE_CHECKING:
from cmrsim.datasets import AnalyticDataset
from cmrsim.analytic.encoding.base import BaseSampling
from cmrsim.reconstruction.base import BaseRecon
from cmrsim.analytic._composite_signal import CompositeSignalModel
from cmrsim.utils.display import SimulationProgressBarII
from time import perf_counter
class AnalyticSimulation(tf.Module):
    """ Entry point to build and call the simulations defined within the cmrsim framework.

    The class is meant to be subclassed, where all subclasses need to implement the
    method `build` defining the building blocks of the simulation to run. An instance is
    created either from explicitly given building blocks or by calling `build` with the
    specified keyword arguments.

    :param name: Name forwarded to the tf.Module constructor
    :param building_blocks: Tuple containing one instance each of
                        'CompositeSignalModel', 'BaseSampling', 'BaseRecon'
                        defining the actual simulation configuration. If it is None, the
                        member function `build` is called instead
    :param build_kwargs: Keyword arguments forwarded to `build`. Only used (and then
                        required) if building_blocks is None
    """
    forward_model: 'CompositeSignalModel' = None
    encoding_module: 'BaseSampling' = None
    # May stay None, in which case __call__ returns k-space data instead of images
    recon_module: 'BaseRecon' = None

    def __init__(self, name: Optional[str] = None,
                 building_blocks: Optional[
                     Tuple['CompositeSignalModel', 'BaseSampling', 'BaseRecon']] = None,
                 build_kwargs: Optional[dict] = None):
        """
        :param name: Name forwarded to the tf.Module constructor
        :param building_blocks: Tuple containing one instance each of
                        'CompositeSignalModel', 'BaseSampling', 'BaseRecon'
                        defining the actual simulation configuration. If it is None, the
                        member function `build` is called instead
        :param build_kwargs: Keyword arguments forwarded to `build`. Only used (and then
                        required) if building_blocks is None
        :raises ValueError: If neither building_blocks nor build_kwargs are specified
        """
        super().__init__(name=name)

        if building_blocks is not None:
            self.forward_model, self.encoding_module, self.recon_module = building_blocks
        elif build_kwargs is not None:
            self.forward_model, self.encoding_module, self.recon_module = \
                self.build(**build_kwargs)
        else:
            raise ValueError("Neither building blocks nor build kwargs are specified")

        # Initialize progress bar; the voxel count is updated on every call of the simulation
        n_k_space_segments = self.encoding_module.k_space_segments.read_value()
        self.progress_bar = SimulationProgressBarII(total_voxels=1, prefix='Run Simulation: ',
                                                    total_segments=n_k_space_segments)

    @abstractmethod
    def build(self, **kwargs) -> Tuple['CompositeSignalModel', 'BaseSampling', 'BaseRecon']:
        """ Abstract method that needs to be defined in subclasses to configure
        specific simulations

        :param kwargs: Simulation specific configuration (the build_kwargs of the constructor)
        :return: Tuple of (forward model, encoding module, reconstruction module)
        :raises NotImplementedError: Always, unless overridden by a subclass
        """
        # Fail loudly instead of returning a sentinel that can not be unpacked in __init__
        raise NotImplementedError(
            f"{type(self).__name__} must implement build() or be instantiated with "
            f"explicit building_blocks")

    def __call__(self, dataset: 'AnalyticDataset', voxel_batch_size: int = 1000,
                 unstack_repetitions: bool = True,
                 trajectory_module=None,
                 trajectory_signatures: dict = None,
                 additional_kwargs: dict = None) -> tf.Tensor:
        """ Runs the simulation for all voxels of the given dataset, adds noise and, if a
        reconstruction module is configured, reconstructs the resulting k-spaces.

        :param dataset: Dataset providing `map_names`, `set_size` and batched access when
                        called with `batchsize`
        :param voxel_batch_size: (int) Number of voxels per batch, passed as `batchsize`
                        when calling the dataset
        :param unstack_repetitions: If False the returned shape will be
                        (#images, #Reps, #noise_levels, #samples).
                        If True refer to 'return' in docstring.
        :param trajectory_module: Optional callable evaluated per batch to compute positions
                        from the batch entry 'initial_positions' (see `_map_trajectories`).
                        If given, 'r_vectors' is not required to be a dataset map
        :param trajectory_signatures: dict mapping batch keys to timing tensors, only used
                        (and then required) in combination with trajectory_module
        :param additional_kwargs: Optional dict of keyword arguments forwarded to every call
                        of trajectory_module
        :return: **(tf.Tensor)** | Stack of images with shape
                (#images, ..., #noise_levels, #X, #Y, #Z), or k-spaces with shape
                (#images, ..., #noise_levels, #samples) where the ellipsis (...) depends on the
                specific forward model. The order in which the *CompositeSignalModel* calls the
                submodules determines the order of dimensions and number of axis in the ellipsis.
        """
        self._update()
        self._validate_dataset(dataset.map_names, trajectory_module is not None)
        self.progress_bar.total_voxels.assign(dataset.set_size)

        # Run actual simulation
        k_space_shape = self.get_k_space_shape()
        batched_dataset = dataset(batchsize=voxel_batch_size)
        noise_less_k_space_stack = self._simulation_loop(batched_dataset, k_space_shape,
                                                         trajectory_module,
                                                         trajectory_signatures,
                                                         additional_kwargs)

        # Add noise instantiations
        k_space_stack = self.encoding_module.add_noise(noise_less_k_space_stack)
        if unstack_repetitions:
            k_space_stack = self.forward_model.unstack_repetitions(k_space_stack)

        # Without a reconstruction module the (noisy) k-space data is returned as is
        if self.recon_module is None:
            return k_space_stack
        return self._reconstruct(k_space_stack)

    def _validate_dataset(self, map_names: Tuple[str, ...], tr_mod: bool):
        """ Checks the dataset for missing parameter maps

        :param map_names: Names of the maps provided by the dataset
        :param tr_mod: True if a trajectory module is used; in that case 'r_vectors' is
                        excluded from the check
        :raises AssertionError: If a quantity required by the forward model is missing
        """
        required = self.forward_model.required_quantities
        if tr_mod:
            to_check = [rq for rq in required if rq != 'r_vectors']
        else:
            to_check = list(required)

        if not all(rq in map_names for rq in to_check):
            raise AssertionError(
                f'{map_names} does not contain all entries of'
                f' {required}')

    def _simulation_loop(self, dataset: tf.data.Dataset, k_space_shape: tf.Tensor,
                         trajectory_module=None, trajectory_signatures: dict = None,
                         additional_kwargs: dict = None):
        """ Consumes all batches of the dataset and accumulates the noise-less k-space signal
        computed by the forward model and the encoding module.

        :param dataset: tf.Dataset that yields dictionaries corresponding to the
                            `cmrsim.datasets.BaseDataset.__call__`
        :param k_space_shape: Shape of the accumulated k-space (see `get_k_space_shape`)
        :param trajectory_module: see `__call__`
        :param trajectory_signatures: see `__call__`
        :param additional_kwargs: see `__call__`
        :return: Stack of noise_less k-spaces as complex64 tensor of shape k_space_shape,
                  i.e. (#repetitions, #samples).
        """
        if additional_kwargs is None:
            additional_kwargs = {}

        # Accumulator for the simulated k-space signal
        s_of_k_temp = tf.zeros(k_space_shape, tf.complex64)
        self.progress_bar.reset_voxel()

        # Loop over material points in dataset
        for batch_idx, batch_dict in dataset.enumerate():
            if trajectory_module is not None:
                batch_dict = self._map_trajectories(batch_idx, batch_dict, trajectory_module,
                                                    trajectory_signatures, additional_kwargs)
            m_transverse = self.forward_model(batch_dict["M0"], segment_index=0, **batch_dict)
            s_of_k_temp += self.encoding_module(m_transverse, batch_dict["r_vectors"])
            self.progress_bar.update(add_voxels=tf.shape(batch_dict['M0'])[0])
            self.progress_bar.print()

        self.progress_bar.print_final()
        return s_of_k_temp

    @staticmethod
    def _map_trajectories(idx: int, batch: dict, trajectory_module,
                          trajectory_signatures: dict, additional_kwargs: dict):
        """ Evaluates the trajectory module for every entry of trajectory_signatures and
        inserts the resulting positions and additional lookups into the batch dictionary.

        :warning: Only the additional lookups of the last entry of the trajectory signatures
                    dict are inserted into the batch
        :param idx: Index of the current batch, forwarded as `batch_index` (cast to int32)
        :param batch: Batch dictionary, must contain 'initial_positions' (which is removed)
        :param trajectory_module: Callable returning a tuple (positions, dict of lookups) for
                    the keyword arguments `initial_positions`, `timing` and `batch_index`
        :param trajectory_signatures: dict mapping batch keys to timing tensors, iterated
                    along their first axis
        :param additional_kwargs: Keyword arguments forwarded to trajectory_module
        :return: The updated batch dictionary
        :raises ValueError: If trajectory_signatures is None or empty
        """
        if not trajectory_signatures:
            raise ValueError("trajectory_signatures must contain at least one entry "
                             "if a trajectory module is used")

        init_r = batch.pop("initial_positions")
        flat_init_r = tf.reshape(init_r, (-1, 3))
        batch_index = tf.cast(idx, tf.int32)

        for key, timing in trajectory_signatures.items():
            new_shape = tf.concat([tf.shape(init_r)[0:1], tf.shape(timing), [3, ]], axis=0)
            # Iterate over repetitions as for some reason this might cause troubles if tried
            # using the flattened time-samples
            positions, per_rep_lookups = [], []
            for t_ in timing:
                p, alup = trajectory_module(initial_positions=flat_init_r,
                                            timing=tf.reshape(t_, (-1, )),
                                            **additional_kwargs, batch_index=batch_index)
                positions.append(p)
                per_rep_lookups.append(alup)
            batch.update({key: tf.reshape(tf.stack(positions, axis=1), new_shape)})

        # NOTE(review): the nesting of this part was reconstructed to follow the loop, in
        # line with the warning in the docstring (only the last entry is used) - confirm.
        # NOTE(review): positions are stacked along axis 1 whereas the lookups are stacked
        # along axis 0 before being reshaped - confirm that this ordering is intended.
        last_lookups = per_rep_lookups[-1]
        lookup_shapes = [tf.concat([new_shape[:-1], tf.shape(last_lookups[name])[2:]], 0)
                         for name in last_lookups.keys()]
        stacked_lookups = {
            name: tf.reshape(tf.stack([alup[name] for alup in per_rep_lookups]), shape)
            for name, shape in zip(last_lookups.keys(), lookup_shapes)}
        batch.update(stacked_lookups)
        return batch

    def get_k_space_shape(self):
        """ Calculates the expected result shape of the simulated k-space, given the configuration
        of encoding and forward-model modules

        :return: tf.Tensor specifying the shape (#repetitions, #k_space_samples)
        """
        n_reps = tf.cast(self.forward_model.expected_number_of_repetitions, tf.int32)
        n_samples = tf.cast(self.encoding_module.number_of_samples, tf.int32)
        return tf.stack([n_reps, n_samples], axis=0)

    def _reconstruct(self, simulated_k_spaces):
        """ Applies the reconstruction module to the simulated k-spaces

        :param simulated_k_spaces: k-space stack as returned by the encoding module
        :return: Result of calling the reconstruction module
        """
        return self.recon_module(simulated_k_spaces)

    def _update(self):
        """ In case there are dependencies between the modules of the simulation, this function
        offers the entry point to adapt all Variables after changes
        """
        self.forward_model.update()
        self.encoding_module.update()

    def save(self, checkpoint_path: str):
        """ Saves a tf checkpoint of the current simulation configuration

        :param checkpoint_path: str
        """
        checkpoint = tf.train.Checkpoint(model=self)
        checkpoint.write(checkpoint_path)

    @classmethod
    def from_checkpoint(cls, checkpoint_path: str):
        """ Creates instance of the class and loads variables from the specified checkpoint.
        The instance is created with empty build_kwargs, hence `build` of the subclass must be
        callable without arguments.

        :param checkpoint_path: str
        :return: Instance of class
        """
        new_model = cls(build_kwargs={})
        new_checkpoint = tf.train.Checkpoint(model=new_model)
        new_checkpoint.restore(checkpoint_path)
        return new_model

    @staticmethod
    def _describe_variable(variable) -> dict:
        """ Summarizes name, dtype, shape, trainable flag and value of a single variable """
        return {'name': variable.name, 'dtype': str(variable.dtype),
                'shape': str(variable.shape),
                'is_trainable': str(variable.trainable),
                'value': variable.numpy().tolist()}

    @property
    def configuration_summary(self) -> OrderedDict:
        """ Creates a summary of the configured variables in the submodules as dictionaries

        :return: OrderedDict with the entries 'forward model' (names of all submodules under
                  'modules' plus one dictionary of variables per submodule name) and 'encoding'
                  (name and variables of the encoding module)
        """
        forward_summary = {'modules': []}
        for module in self.forward_model.submodules:
            forward_summary['modules'].append(module.name)
            forward_summary[module.name] = {
                variable.name: self._describe_variable(variable)
                for variable in module.variables}

        encoding_summary = {'name': self.encoding_module.name}
        for variable in self.encoding_module.variables:
            encoding_summary[variable.name] = self._describe_variable(variable)

        summary = OrderedDict()
        summary['forward model'] = forward_summary
        summary['encoding'] = encoding_summary
        return summary

    def write_graph(self, dataset: 'AnalyticDataset', graph_log_dir: str):
        """ Traces the graph of the simulation for one batch and saves it as graphdef to the
        specified folder

        :param dataset: Dataset, called with a batch size of 100 to obtain a single batch
        :param graph_log_dir: Folder to write the summary to, created if it does not exist
        """
        os.makedirs(graph_log_dir, exist_ok=True)
        writer = tf.summary.create_file_writer(graph_log_dir)
        print("Graph Tracing activated...")
        tf.summary.trace_on(graph=True)
        try:
            batch_dict, = list(dataset(100).take(1))
            m_transverse = self.forward_model(batch_dict["M0"], segment_index=0, **batch_dict)
            _ = self.encoding_module(m_transverse, batch_dict["r_vectors"])
        except Exception:
            # Do not leave the tracing switched on if the traced calls fail
            tf.summary.trace_off()
            raise
        with writer.as_default():  # pylint: disable=not-context-manager
            tf.summary.trace_export(name='graph_def', step=0)
        print(f'\nSaved Graph-definition to {graph_log_dir}')