# Source code for cmrsim.utils.display

__all__ = ["SimulationProgressBarI", "SimulationProgressBarII"]
from typing import Tuple, Optional, TYPE_CHECKING
from abc import abstractmethod
import sys
import os
from collections import OrderedDict

import tensorflow as tf


class SimulationProgressBarI(tf.Module):
    """Graph-execution compatible command line progress-bar for running simulations.

    A single counter (``current``) is tracked against a fixed ``target``. The bar
    is rendered with TensorFlow string ops and written to ``sys.stdout`` with a
    leading carriage return. It is printed automatically whenever the counter is
    a multiple of ``print_step`` after an :meth:`update` or :meth:`add` call.
    """
    # pylint: disable=too-many-instance-attributes
    target: tf.Variable
    current: tf.Variable

    # pylint: disable=too-many-arguments
    def __init__(self, total: int, prefix: str = '', suffix: str = '',
                 bar_width: int = 40, print_step: int = 20):
        """Construct a progress-bar instance.

        :param total: value of the counter that corresponds to a completely filled bar
        :param prefix: string printed in front of the bar
        :param suffix: string printed behind the bar
        :param bar_width: number of characters between the two ``|`` delimiters
        :param print_step: the bar is printed when the counter is a multiple of this value
        """
        super().__init__(name='progress_bar')
        with self.name_scope:
            self.target = tf.Variable(total, dtype=tf.int32, trainable=False,
                                      name='total_voxels')
            self.current = tf.Variable(0, dtype=tf.int32, trainable=False,
                                       name="current_voxel")
            self.bar_length = tf.constant(bar_width, dtype=tf.int32)
        self.prefix = prefix
        self.suffix = suffix
        self.print_step = print_step

    def _print_if_due(self):
        """Print the bar if the counter currently is a multiple of ``print_step``."""
        if tf.math.mod(self.current, self.print_step) == 0:
            self.print()

    def update(self, current: int):
        """Set the counter to ``current`` and print the bar if a print step is hit.

        :param current: new absolute value of the counter (cast to int32)
        """
        self.current.assign(tf.cast(current, tf.int32))
        self._print_if_due()

    def add(self, n_add: int):
        """Increase the counter by ``n_add`` and print the bar if a print step is hit.

        :param n_add: increment added to the counter (cast to int32)
        """
        self.current.assign_add(tf.cast(n_add, tf.int32))
        self._print_if_due()

    def finalize(self):
        """Set the counter to its target value and print the progress bar once."""
        # Assign directly instead of going through update(): update() may already
        # print, which resulted in the final state being printed twice.
        self.current.assign(self.target.read_value())
        self.print()

    def reset(self):
        """Reset the counter to 0 without printing."""
        self.current.assign(0)

    @tf.function()
    def print(self):
        """Graph-compatible definition of printing the current state of the progress bar."""

        def bar_string(iteration, total, length):
            # divide_no_nan yields 0 for total == 0, i.e. an empty bar.
            filled_length = tf.cast(
                tf.math.divide_no_nan(tf.cast(length * iteration, tf.float32),
                                      tf.cast(total, tf.float32)),
                tf.int32)
            # Clamp so that a counter outside [0, total] cannot produce a
            # negative repeat count in either of the tf.repeat calls below.
            filled_length = tf.clip_by_value(filled_length, 0, length)
            prog_bar = tf.concat((tf.repeat('X', filled_length),
                                  tf.repeat('-', length - filled_length)), 0)
            ret_list = tf.concat(
                (('|',), prog_bar,
                 ('|', tf.strings.format('{}/{}', [iteration, total]))), 0)
            return tf.strings.reduce_join(ret_list)

        bar = bar_string(self.current.read_value(), self.target.read_value(),
                         self.bar_length)
        full_string = tf.strings.reduce_join(
            ('\r', self.prefix, '\t', bar, self.suffix))
        tf.print(full_string, end=' ', output_stream=sys.stdout)


class SimulationProgressBarII(tf.Module):
    """Graph-execution compatible command line progress-bar for running simulations.

    Two sub-bars are displayed side by side: one counting voxels and one counting
    segments. Both are rendered with TensorFlow string ops and written to
    ``sys.stdout`` with a leading carriage return. Printing is never triggered
    implicitly; call :meth:`print` or :meth:`print_final`.
    """
    # pylint: disable=too-many-instance-attributes
    total_voxels: tf.Variable
    total_segments: tf.Variable
    current_voxel: tf.Variable
    current_segment: tf.Variable

    # pylint: disable=too-many-arguments
    def __init__(self, total_voxels: int = 0, total_segments: Optional[int] = None,
                 prefix: str = '', suffix: str = ''):
        """Construct a progress-bar instance with two sub-progress-bars.

        The totals of each sub-bar are given by the arguments.

        :param total_voxels: counter value corresponding to a full voxel bar
        :param total_segments: counter value corresponding to a full segment bar;
            ``None`` is treated as 0 (bar stays empty)
        :param prefix: string printed in front of the bars
        :param suffix: string printed behind the bars
        """
        super().__init__(name='progress_bar')
        # tf.Variable cannot be created from None, so the declared defaults would
        # otherwise make the no-argument constructor fail.
        if total_voxels is None:
            total_voxels = 0
        if total_segments is None:
            total_segments = 0
        with self.name_scope:
            self.total_voxels = tf.Variable(total_voxels, dtype=tf.int32,
                                            trainable=False, name='total_voxels')
            self.total_segments = tf.Variable(total_segments, dtype=tf.int32,
                                              trainable=False, name='total_segments')
            self.current_voxel = tf.Variable(0, dtype=tf.int32, trainable=False,
                                             name="current_voxel")
            self.current_segment = tf.Variable(0, dtype=tf.int32, trainable=False,
                                               name="current_segment")
        # Widths (in characters) of the (voxel, segment) sub-bars.
        self.bar_lengths = (20, 15)
        self.prefix = prefix
        self.suffix = suffix

    def update(self, add_voxels: int = None, add_segment: int = None,
               set_voxel: int = None, set_segment: int = None):
        """Update the state of the progress bar.

        Each counter can either be incremented (``add_*``) or set to an absolute
        value (``set_*``). Arguments left at ``None`` are ignored. Increments are
        applied before absolute assignments, so a ``set_*`` argument overrides the
        matching ``add_*`` argument given in the same call. Nothing is printed.

        :param add_voxels: increment for the voxel counter
        :param add_segment: increment for the segment counter
        :param set_voxel: new absolute value of the voxel counter
        :param set_segment: new absolute value of the segment counter
        """
        if add_voxels is not None:
            self.current_voxel.assign_add(tf.cast(add_voxels, tf.int32))
        if add_segment is not None:
            self.current_segment.assign_add(tf.cast(add_segment, tf.int32))
        if set_voxel is not None:
            self.current_voxel.assign(tf.cast(set_voxel, tf.int32))
        if set_segment is not None:
            self.current_segment.assign(tf.cast(set_segment, tf.int32))

    def reset_voxel(self):
        """Reset the counter of the voxel sub-bar to 0."""
        self.current_voxel.assign(0)

    def reset_segments(self):
        """Reset the counter of the segment sub-bar to 0."""
        self.current_segment.assign(0)

    def print_final(self):
        """Set all counters to their total max and print the progress bar."""
        self.update(set_voxel=self.total_voxels.read_value(),
                    set_segment=self.total_segments.read_value())
        self.print()

    @tf.function
    def print(self):
        """Graph-compatible definition of printing the current state of the progress bar."""

        def bar_string(iteration, total, length):
            # divide_no_nan yields 0 for total == 0, i.e. an empty bar.
            filled_length = tf.cast(
                tf.math.divide_no_nan(tf.cast(length * iteration, tf.float32),
                                      tf.cast(total, tf.float32)),
                tf.int32)
            # Clamp so that a counter outside [0, total] cannot produce a
            # negative repeat count in either of the tf.repeat calls below.
            filled_length = tf.clip_by_value(filled_length, 0, length)
            prog_bar = tf.concat((tf.repeat('X', filled_length),
                                  tf.repeat('-', length - filled_length)), 0)
            ret_list = tf.concat(
                (('|',), prog_bar,
                 ('|', tf.strings.format('{}/{}', [iteration, total]))), 0)
            return tf.strings.reduce_join(ret_list)

        voxel_width, segment_width = self.bar_lengths
        # The voxel bar previously used the segment width as well, which left the
        # first entry of bar_lengths unused.
        voxel_bar = bar_string(self.current_voxel.read_value(),
                               self.total_voxels.read_value(), voxel_width)
        segment_bar = bar_string(self.current_segment.read_value(),
                                 self.total_segments.read_value(), segment_width)
        full_string = tf.strings.reduce_join(
            ('\r', self.prefix, '\t', voxel_bar, '\t', segment_bar, self.suffix))
        tf.print(full_string, end=' ', output_stream=sys.stdout)