# Source code for cmrsim.datasets._analytic
__all__ = ["AnalyticDataset"]
from typing import TYPE_CHECKING, Dict
from collections import OrderedDict
import numpy as np
import tensorflow as tf
from cmrsim.datasets._base import BaseDataset
if TYPE_CHECKING:
from cmrsim.trajectory._base import BaseTrajectoryModule
# [docs]
class AnalyticDataset(BaseDataset):
    """Dataset built from a dictionary of per-material-point numpy arrays.

    The constructor validates dtypes and shapes of the arrays and then hands the
    dictionary to ``BaseDataset``.
    """

    def __init__(self, array_dictionary: OrderedDict, filter_inputs: bool = True,
                 expand_dimension: bool = False):
        """ Initializes a callable module, that yields an iterable tf.Dataset on call. The target
        shape of the yielded batches is (#batch, #repetitions, #kspace-samples, ...)

        Array layout required after the optional dimension expansion:

        - ``M0``: complex64, shape (#MaterialPoints, #Repetitions, #k-space-samples)
        - ``r_vectors`` (optional): float32,
          shape (#MaterialPoints, #Repetitions, #k-space-samples, 3)
        - every other entry: float32

        At least one of ``r_vectors`` (positions pre-computed prior to simulation) or
        ``initial_positions`` (used in combination with a trajectory module supporting batched
        calls) must be present.

        :param array_dictionary: (OrderedDict) containing the required quantities as numpy arrays
        :param filter_inputs: If set to true, trivial material points are filtered (M0=0),
                              on instantiation
        :param expand_dimension: If True, ``M0`` must be 1-D (one entry per material point) and
                                 two singleton axes are inserted directly after the first axis
                                 of every array in the dictionary.
        :raises ValueError: if ``expand_dimension`` is True but ``M0`` has more than one axis,
                            if ``M0`` is not complex64 or any other array is not float32,
                            if ``M0`` does not have 3 axes, if neither ``r_vectors`` nor
                            ``initial_positions`` is given, or if ``r_vectors`` does not have
                            4 axes with a trailing axis of size 3.
        """
        if expand_dimension:
            if len(array_dictionary["M0"].shape) > 1:
                shape_list = [f"{k}: {v.shape}" for k, v in array_dictionary.items()]
                # Message prefix and shape listing are concatenated explicitly; previously the
                # whole message was (accidentally) used as the join-separator.
                raise ValueError("Arrays appear to have repetition/sample axis already. Please "
                                 "disable the expand-dims argument. Shapes: \n\t"
                                 + "\n\t".join(shape_list))
            # Insert singleton #Repetitions and #k-space-samples axes after the first axis.
            array_dictionary = {k: v[:, np.newaxis, np.newaxis]
                                for k, v in array_dictionary.items()}

        if array_dictionary["M0"].dtype != np.dtype(np.complex64):
            raise ValueError("M0 is not of required type complex64")
        if not all(v.dtype == np.dtype(np.float32)
                   for k, v in array_dictionary.items() if k != "M0"):
            dtype_list = [f"{k}: {v.dtype}" for k, v in array_dictionary.items() if k != "M0"]
            raise ValueError("Not all arrays are of required type np.float32: \n\t" +
                             "\n\t".join(dtype_list))
        m0_shape = array_dictionary["M0"].shape
        if len(m0_shape) != 3:
            raise ValueError(f"Shape of input array M0 invalid: expected 3 axes "
                             f"(#MaterialPoints, #Repetitions, #k-space-samples), "
                             f"got shape {m0_shape}")

        r_vectors = array_dictionary.get("r_vectors")
        if r_vectors is None and array_dictionary.get("initial_positions") is None:
            raise ValueError(
                "Either 'r_vectors' or 'initial_positions' must be specified as mandatory key "
                f"in the dictionary containing the arrays {list(array_dictionary.keys())}. "
                "The former is used when positions are pre-computed prior to simulation, "
                "the latter in combination with a trajectory module supporting batched calls.")
        if r_vectors is not None:
            if len(r_vectors.shape) != 4:
                raise ValueError(f"Shape of input array r_vectors invalid: expected 4 axes "
                                 f"(#MaterialPoints, #Repetitions, #k-space-samples, 3), "
                                 f"got shape {r_vectors.shape}")
            if r_vectors.shape[-1] != 3:
                raise ValueError("Last dimension of r-vectors is not 3, it is likely that the "
                                 "array axes are wrongly ordered. Refer to doc-string.\n"
                                 f"Given shape: {r_vectors.shape}")

        super().__init__(filter_inputs=filter_inputs, array_dictionary=array_dictionary)
        # Read back from the base class, which may have filtered material points.
        self.set_size = self._array_dict["M0"].shape[0]
        self.map_names = tuple(self._array_dict.keys())