Source code for cmrsim.datasets._base

""" This module contains the base implementation of a Dataset containing a digital phantom defined
 as a dictionary of numpy arrays, fitting into RAM at once"""

__all__ = ["BaseDataset"]

from abc import abstractmethod

from typing import Optional, Tuple
from collections import OrderedDict

import numpy as np
import tensorflow as tf


class BaseDataset(tf.Module):
    """Basic implementation of a module that constructs a tf.data.Dataset from the dictionary
    of numpy arrays, describing the digital phantom.
    """

    #: Names of simulation quantities passed as dictionary keys on construction
    map_names: Tuple = None
    #: Number of anatomies (0th - axis) of the passed simulation quantities
    set_size: int = None
    #: Dictionary of simulation quantities (filtered if requested on construction)
    _array_dict: dict
    #: Indices to obtain the filtered arrays from the original inputs; None if the inputs
    #: were not filtered on construction
    filter_indices: Optional[tf.Tensor]

    def __init__(self, filter_inputs: bool, array_dictionary: dict):
        """
        :param filter_inputs: If set to true, trivial material points are filtered (M0=0),
                                on instantiation
        :param array_dictionary: (OrderedDict) containing the required quantities as numpy
                                arrays. Must contain the key 'M0' if filter_inputs is true.
        """
        super().__init__()
        if filter_inputs:
            self._array_dict, self.filter_indices = self._filter_inputs(array_dictionary)
        else:
            # Unfiltered: keep a reference to the caller's dictionary (no copy is made)
            self._array_dict = array_dictionary
            self.filter_indices = None

    def __call__(self, batchsize: int = 1000, prefetch: int = 5) -> tf.data.Dataset:
        """Returns a tf.data.Dataset built by slicing the stored arrays along their first axis.
        The dataset is batched and prefetched, and each yielded batch is a dictionary like:
        {'M0': tf.Tensor(...), ...}. The keys of the dict are the keys of the stored array
        dictionary, while the values each are Tensors representing one batch of entries
        (iso-chromates/grid-positions) of the corresponding array.

        The returned dataset is supposed to be iterated as:

        .. code::

            for batch in dataset(batchsize=500):
                # batch = {magnetization:( ), trajectories: ( ), T1: (-1), ....}
                ...

        :param batchsize: (int) batch size
        :param prefetch: (int) prefetched batches
        :return: tf.data.Dataset
        """
        sliced = tf.data.Dataset.from_tensor_slices(self._array_dict)
        return sliced.batch(batchsize).prefetch(prefetch)

    @staticmethod
    def _filter_inputs(array_dict):
        """Filter function, to select only the entries that have an initial
        magnetization with non-zero magnitude (abs(M0) > 0).

        :param array_dict: dictionary of arrays; must contain the key 'M0'. Every value is
                            gathered along axis 0 with the same indices.
        :return: (OrderedDict of filtered tensors, tensor of selected axis-0 indices)
        :raises KeyError: if 'M0' is not a key of array_dict
        """
        m0 = array_dict["M0"]
        # tf.where yields one coordinate row per non-zero element; only the axis-0 coordinate
        # is kept.
        # NOTE(review): for M0 with rank > 1 this produces repeated axis-0 indices (one per
        # non-zero element) -- presumably M0 is expected to be 1-D here; confirm with callers.
        indices = tf.where(tf.abs(m0) > 0.)[:, 0]
        new_dict = OrderedDict(
            (name, tf.gather(values, indices, axis=0)) for name, values in array_dict.items()
        )
        return new_dict, indices