Proper orthogonal decomposition (POD) trajectories

The POD trajectory module computes the proper orthogonal decomposition of mesh data with shape (#Mesh nodes, #Time-steps, #Channels) and fits a Taylor expansion to the resulting mode-weight functions. This notebook demonstrates the use case of a contracting left-ventricular (LV) mesh, including fiber and sheet orientations.

Import

[1]:
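import sys
# Add parent directories to the import path (e.g. for local_functions and a local cmrsim checkout)
sys.path.append("../")
sys.path.insert(0, "../../")

import os
# Silence TensorFlow C++ logging and select the GPU to use (adjust the index to your machine)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ['CUDA_VISIBLE_DEVICES'] = "3"

import tensorflow as tf
# Allocate GPU memory on demand instead of reserving all of it up front
gpu = tf.config.get_visible_devices("GPU")
if gpu:
    tf.config.experimental.set_memory_growth(gpu[0], True)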
[2]:
from IPython.display import display, HTML, clear_output
import base64
import imageio

import pyvista
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from pint import Quantity

import cmrsim
import local_functions

Load cardiac mesh

The full example data set can be downloaded and unzipped into the local folder “test_data” as follows:

[4]:
# Download the example data, move it to ./test_data and remove the temporary files
!wget "https://gitlab.ethz.ch/ibt-cmr/modeling/cmr-random-diffmaps/-/archive/main/cmr-random-diffmaps-main.zip?path=notebooks/example_data" -O test_data.zip
!unzip test_data.zip
!rm test_data.zip
!mv cmr-random-diffmaps-main-notebooks-example_data/notebooks/example_data test_data
!rm -r cmr-random-diffmaps-main-notebooks-example_data
clear_output()

Now load the meshes (from VTK files) and the timings defined in the text file PV_loop_reordered.txt:

[5]:
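# Displacement snapshots, one VTK file per time step
files = [f'test_data/example_mesh/Displacements/Displ{i:03}.vtk' for i in range(0, 200, 1)]
# The second column of PV_loop_reordered.txt holds the snapshot times in ms
timing = Quantity(np.loadtxt("test_data/example_mesh/PV_loop_reordered.txt")[:, 1], "ms")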

original_mesh = cmrsim.datasets.CardiacMeshDataset.from_list_of_meshes(files, timing[:len(files)],
                                                                       mesh_numbers=(4, 40, 30),
                                                                       time_precision_ms=3)
Loading Files... : 100%|██████████| 199/199 [00:03<00:00, 55.42it/s]

Explanation/Exploration of the POD results

To give insight into how the POD module works, this first section demonstrates the underlying computation and highlights intermediate results.

Calculate POD

Usually, this is done when the module is instantiated. The explicit call that computes the decomposition, as well as its visualization, is included for demonstration purposes only.

[7]:
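# Node positions r (shape: #Mesh nodes, #Time-steps, 3) at the snapshot times t
t, r = original_mesh.get_trajectories(start=Quantity(0, "ms"),
                                      end=Quantity(5000, "ms"))
n_modes = 5
# phi: flattened spatial modes, weights: mode-weights per snapshot time
phi, weights = cmrsim.trajectory.PODTrajectoryModule.calculate_pod(t.m_as("ms"), data=r.m_as("m"), n_modes=n_modes)
print(phi.shape, weights.shape)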
(14412, 5) (199, 5)
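For reference, the decomposition computed by calculate_pod above can be sketched in plain NumPy as a thin SVD of the snapshot matrix. This is a minimal sketch assuming no mean subtraction or normalization (consistent with the reconstruction formula used in this notebook); cmrsim's calculate_pod may differ in such details, and pod_sketch is a hypothetical helper, not part of cmrsim. It follows the shape convention printed above: modes of shape (#Mesh nodes × #Channels, n_modes) and weights of shape (#Time-steps, n_modes).

```python
import numpy as np


def pod_sketch(data: np.ndarray, n_modes: int):
    """Hypothetical helper: POD of mesh snapshots via a thin SVD.

    Assumes no mean subtraction. data has shape (n_nodes, n_timesteps, n_channels).
    Returns phi of shape (n_nodes * n_channels, n_modes) and
    weights of shape (n_timesteps, n_modes).
    """
    n_nodes, n_t, n_ch = data.shape
    # Snapshot matrix with one column per time step; rows are ordered
    # node-major / channel-minor so phi[:, j].reshape(-1, n_ch) is a per-node field.
    snapshots = np.transpose(data, (0, 2, 1)).reshape(n_nodes * n_ch, n_t)
    u, s, vt = np.linalg.svd(snapshots, full_matrices=False)
    phi = u[:, :n_modes]                            # orthonormal spatial modes
    weights = (s[:n_modes, None] * vt[:n_modes]).T  # temporal mode-weights w_j(t)
    return phi, weights


# Synthetic stand-in for r.m_as("m"): 50 nodes, 20 time steps, 3 channels
rng = np.random.default_rng(0)
data_demo = rng.normal(size=(50, 20, 3))
phi_demo, weights_demo = pod_sketch(data_demo, n_modes=5)
print(phi_demo.shape, weights_demo.shape)  # (150, 5) (20, 5)
```

Keeping all singular values (n_modes equal to the number of time steps) reproduces the snapshots exactly; truncating to the leading modes gives the best low-rank approximation in the Frobenius norm.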

Visualize the modes of the computed POD

[8]:
pyvista.close_all()
pyvista.start_xvfb()  # virtual framebuffer for headless rendering
plotter = pyvista.Plotter(off_screen=True, window_size=(1600, 400), shape=(1, n_modes))
for n in range(n_modes):
    plotter.subplot(0, n)
    plotter.add_title(f"Mode {n}", font_size=12)
    # Reshape the flattened mode to one 3D vector per mesh node and render it as a point cloud
    plotter.add_mesh(phi.T[n].reshape(-1, 3), render_points_as_spheres=True)
img = plotter.screenshot("POD_modes.png")
pyvista.close_all()
with open("POD_modes.png", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("ascii")
display(HTML(f'<img src="data:image/png;base64,{b64}" />'))

Plot mode contributions over time and render the reconstructed mesh

A motion state \(u(t)\) of the mesh is reconstructed by calculating the weighted sum of the mesh modes for each time point:

\[u(t) = \sum_{j=0}^{N_{modes}-1} \phi_j \, w_j(t),\]

where \(\phi_j\) are the computed basis functions (modes) shown above and \(w_j(t)\) are the corresponding mode-weights as functions of time, shown below. The motion states reconstructed from 5 modes are shown as an animation below.
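How much of the motion a given number of modes captures can be read off the singular values \(\sigma_0 \ge \sigma_1 \ge \dots\) of the snapshot matrix: by the Eckart-Young theorem, the relative Frobenius-norm error of keeping the first \(k\) modes is \(\sqrt{\sum_{j\ge k}\sigma_j^2 / \sum_j \sigma_j^2}\). The sketch below assumes snapshots arranged as a (#Mesh nodes × #Channels, #Time-steps) matrix and uses random data as a stand-in; truncation_error is a hypothetical helper, not part of cmrsim.

```python
import numpy as np


def truncation_error(snapshots: np.ndarray, n_modes: int) -> float:
    """Hypothetical helper: relative Frobenius error of keeping n_modes POD modes.

    snapshots: (n_nodes * n_channels, n_timesteps) matrix, one column per time step.
    """
    s = np.linalg.svd(snapshots, compute_uv=False)  # singular values, descending
    return float(np.sqrt(np.sum(s[n_modes:] ** 2) / np.sum(s ** 2)))


rng = np.random.default_rng(1)
snapshots_demo = rng.normal(size=(300, 40))  # stand-in snapshot matrix
for k in (1, 5, 10):
    print(k, truncation_error(snapshots_demo, k))
```

For smooth, periodic motion such as a contracting ventricle the singular values typically decay quickly, which is why a handful of modes can suffice; random data like the stand-in above decays slowly and needs many more.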

[12]:
# Plot mode weights as function of time
fig, ax = plt.subplots(1, 1, figsize=(12, 5), dpi=100)
ax.plot(t.m, weights);
ax.set_xlabel("Time (ms)"), ax.grid(alpha=0.5), ax.set_title("Mode contributions over time")
ax.legend(["$\\phi_{i}$".format(i=n) for n in range(0, phi.shape[1])], ncol=5)
fig.tight_layout()
fig.savefig("modecontributions.png")
plt.close("all")
clear_output()

# Render reconstructed mesh
t_indices = np.linspace(0, t.shape[0]-1, 50).astype(int)
phi_dash = phi.T.reshape(n_modes, -1, 3)
trajectories = np.stack([np.sum((phi_dash * weights[t_idx].reshape(-1, 1, 1)), axis=0) for t_idx in tqdm(t_indices)])

plotter = pyvista.Plotter(off_screen=False, window_size=(500, 500), shape=(1, 1))
local_functions.animate_trajectories(plotter, trajectories,
                                     timing=[t.m_as("ms")[_] for _ in t_indices],
                                     filename="PODRecon.gif",
                                     mesh_kwargs=dict(render_points_as_spheres=True),
                                     visualize_trajectories=False,
                                     trajectory_kwargs=dict(opacity=0.5, color="red"))
pyvista.close_all()

# Combine figures and display
imgs = [imageio.v3.imread(fname) for fname in ["modecontributions.png", "PODRecon.gif"]]
stacked_frames = [np.concatenate([imgs[0][..., :3], frame], axis=1) for frame in imgs[1]]
imageio.mimsave("meshrecon.gif", stacked_frames, loop=10, duration=0.3)
b64 = base64.b64encode(open("meshrecon.gif", 'rb').read()).decode('ascii')
display(HTML(f'<img src="data:image/gif;base64,{b64}" />'))
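The list comprehension in the cell above evaluates \(u(t) = \sum_j \phi_j w_j(t)\) one time index at a time. The same sum can be written as a single np.einsum over all selected time indices. This is a minimal sketch assuming phi of shape (#Mesh nodes × 3, n_modes) and weights of shape (#Time-steps, n_modes), as printed earlier; random arrays stand in for the notebook variables.

```python
import numpy as np

rng = np.random.default_rng(2)
n_nodes, n_t, n_modes_demo = 100, 30, 5
phi_demo = rng.normal(size=(n_nodes * 3, n_modes_demo))  # stand-in for phi
weights_demo = rng.normal(size=(n_t, n_modes_demo))      # stand-in for weights
t_idx_demo = np.linspace(0, n_t - 1, 10).astype(int)

# Modes as per-node vector fields: (n_modes, n_nodes, 3)
modes = phi_demo.T.reshape(n_modes_demo, -1, 3)

# u[t, n, c] = sum_j weights[t, j] * modes[j, n, c]  ->  (len(t_idx), n_nodes, 3)
u_vec = np.einsum("tj,jnc->tnc", weights_demo[t_idx_demo], modes)

# Loop version, as in the notebook cell, for comparison
u_loop = np.stack([np.sum(modes * weights_demo[i].reshape(-1, 1, 1), axis=0) for i in t_idx_demo])
print(np.allclose(u_vec, u_loop))  # True
```

The einsum form avoids building a Python list of per-time arrays, which matters when many time points are reconstructed.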

Calling the POD module

In this second section, the POD module is instantiated and called; the computation shown above is performed internally when the module is instantiated. The mode-weights are represented by a Taylor series of specified order, which allows the motion state to be reconstructed at arbitrary times within the time interval spanned by the given mesh snapshots.
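The idea of representing each mode-weight \(w_j(t)\) by a polynomial in time can be sketched with NumPy's least-squares polyfit. This is an illustration under assumptions, not cmrsim's implementation: the expansion point, time scaling and fitting procedure used internally by PODTrajectoryModule are not reproduced here, and fit_weight_polynomials is a hypothetical helper. Time is mapped to [-1, 1] before fitting, because a degree-10 fit in raw milliseconds is badly conditioned.

```python
import numpy as np
from numpy.polynomial import polynomial as P


def fit_weight_polynomials(t: np.ndarray, weights: np.ndarray, poly_order: int):
    """Hypothetical helper: least-squares polynomial fit of every mode-weight column.

    t: (n_timesteps,), weights: (n_timesteps, n_modes).
    Returns a function mapping new times (inside [t.min(), t.max()]) to (len(t_new), n_modes).
    """
    t_min, t_max = float(np.min(t)), float(np.max(t))

    def to_unit(x):
        # Map the snapshot time interval to [-1, 1] for numerical conditioning
        return 2.0 * (np.asarray(x, dtype=float) - t_min) / (t_max - t_min) - 1.0

    # polyfit fits all columns at once; coef has shape (poly_order + 1, n_modes)
    coef = P.polyfit(to_unit(t), weights, deg=poly_order)

    def evaluate(t_new):
        # polyval with 2-D coefficients returns (n_modes, len(t_new)); transpose back
        return P.polyval(to_unit(t_new), coef).T

    return evaluate


# Synthetic example: smooth weights sampled at 199 snapshot times (ms)
t_demo = np.linspace(0.0, 1000.0, 199)
s = t_demo / 1000.0
weights_demo = np.stack([np.sin(np.pi * s), s ** 2, 1.0 - s ** 3], axis=1)
w_of_t = fit_weight_polynomials(t_demo, weights_demo, poly_order=10)
new_times = np.linspace(0.0, 1000.0, 100)
print(w_of_t(new_times).shape)  # (100, 3)
```

Multiplying the evaluated weights with the spatial modes then gives the mesh at the new times; evaluating outside the snapshot interval is extrapolation and is not expected to be reliable for high polynomial orders.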

[13]:
tf.config.run_functions_eagerly(True)
# Get mesh-snap shots
t, r = original_mesh.get_trajectories(start=Quantity(0, "ms"), end=Quantity(5000, "ms"))

# Define number of modes and degree of Taylor series to represent the mode-weights
n_modes = 5
poly_order = 10

# Instantiate module
module = cmrsim.trajectory.PODTrajectoryModule(t.m_as("ms"), r.m_as("m"), n_modes=n_modes, poly_order=poly_order)
# Or do batched calls:
# module = cmrsim.trajectory.PODTrajectoryModule(t.m_as("ms"), data=r.m_as("m"), n_modes=n_modes, poly_order=poly_order, batch_size=1000)

# Define new time-points
new_time_grid = np.linspace(t[0].m, t[-1].m, 100).astype(np.float32)

# Get reconstructed mesh at new time-points
reconstructed_mesh_points, _ = module(None, new_time_grid)

# (batched call continued)
reconstructed_mesh_points = tf.concat([module(None, new_time_grid, batch_index=i)[0] for i in range(10)], axis=0)

clear_output()
reconstructed_mesh_points.shape
[13]:
TensorShape([4804, 100, 3])

Plot the results of reconstruction

By accessing the _evaluate_trajectory method of the module's internal Taylor-series module (module._taylor_module; this is for demonstration purposes only), we can also obtain the mode-weights at the new time points:

[14]:
mode_weights = tf.transpose(module._taylor_module._evaluate_trajectory(new_time_grid)[:, :, 0], [1, 0])
[17]:
# Plot mode weights as function of time
fig, ax = plt.subplots(1, 1, figsize=(12, 5), dpi=100)
ax.plot(new_time_grid, mode_weights);
ax.set_xlabel("Time (ms)"), ax.grid(alpha=0.5), ax.set_title("Mode contributions over time (module call)")
ax.legend(["$\\phi_{i}$".format(i=n) for n in range(n_modes)], ncol=5)
fig.tight_layout()
fig.savefig("modecontributions.png")
plt.close("all")
clear_output()

# Render reconstructed mesh
plotter = pyvista.Plotter(off_screen=False, window_size=(500, 500), shape=(1, 1))
local_functions.animate_trajectories(plotter, np.swapaxes(reconstructed_mesh_points.numpy(), 0, 1),
                                     timing=new_time_grid,
                                     filename="PODRecon.gif",
                                     mesh_kwargs=dict(render_points_as_spheres=True),
                                     visualize_trajectories=False,
                                     trajectory_kwargs=dict(opacity=0.5, color="red"))
plotter.close()

# Combine figures and display
imgs = [imageio.v3.imread(fname) for fname in ["modecontributions.png", "PODRecon.gif"]]
stacked_frames = [np.concatenate([imgs[0][..., :3], frame], axis=1) for frame in imgs[1]]
imageio.mimsave("meshrecon.gif", stacked_frames, loop=4, duration=0.3)
b64 = base64.b64encode(open("meshrecon.gif", 'rb').read()).decode('ascii')
display(HTML(f'<img src="data:image/gif;base64,{b64}" />'))