Contracting LV with diffusion random walk#

This notebook presents the DiffusionTaylor trajectory module, which combines deterministic particle trajectories with local random-walk diffusion motion. To this end, the module fits Taylor expansions both to the deterministic trajectories and to the time-varying eigenvector orientation of a 3x3 tensor that describes Gaussian diffusion.

Imports#

[1]:
from IPython.display import display, HTML, clear_output
import base64

import tensorflow as tf
# Restrict TensorFlow to the first visible GPU and let it allocate memory on demand.
# Without a visible GPU, indexing [0] would raise IndexError; fall back to CPU instead so the
# notebook still runs top to bottom (slower).
visible_gpus = tf.config.get_visible_devices("GPU")
if visible_gpus:
    gpu = visible_gpus[0]
    tf.config.set_visible_devices(gpu, device_type="GPU")
    tf.config.experimental.set_memory_growth(gpu, True)
else:
    print("No GPU visible to TensorFlow - running on CPU.")

import imageio
import pyvista
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from pint import Quantity

import sys
# Put the directory two levels up at the front of the module search path, so its `cmrsim` is
# presumably picked up before any installed copy (TODO confirm this is the repository root).
sys.path.insert(0, "../../")
# The parent directory is appended so the notebook helper module `local_functions` can be imported.
sys.path.append("..")
# These imports must stay after the sys.path changes above.
import cmrsim
import local_functions

Load deterministic trajectories#

The node trajectories of a sliced, contracting left-ventricle mesh are used to define the deterministic part of the trajectory module. The eigenvectors of the diffusion tensor are defined by the fiber and sheet directions stored in the reference mesh. In the cells below, the mesh is loaded and rendered, including the fiber and sheet directions.

[2]:
# All example-mesh inputs live below this directory; change it here to point at another checkout.
example_mesh_dir = "../../../cmr-random-diffmaps/notebooks/example_data/example_mesh"
n_mesh_frames = 73  # Displ000.vtk ... Displ072.vtk

files = [f'{example_mesh_dir}/Displacements/Displ{i:03}.vtk' for i in range(n_mesh_frames)]
# Column 1 of the PV-loop file is used as the time axis of the displacement meshes (in ms).
timing = Quantity(np.loadtxt(f"{example_mesh_dir}/PV_loop_reordered.txt")[:, 1], "ms")
# Fail loudly instead of silently pairing meshes with a too-short time axis.
assert len(timing) >= len(files), (
    f"PV-loop file has {len(timing)} time points but {len(files)} displacement meshes were listed")

original_mesh = cmrsim.datasets.CardiacMeshDataset.from_list_of_meshes(files, timing[:len(files)],
                                                                       mesh_numbers=(4, 40, 30),
                                                                       time_precision_ms=3)
refined_mesh = original_mesh.refine(longitudinal_factor=2, circumferential_factor=2, radial_factor=3)
# Slice selection parameters: 10 mm thickness, normal along z, positioned at z = -3 cm,
# evaluated at the 105 ms reference time (exact semantics defined by `select_slice`).
slice_dict = dict(slice_thickness=Quantity(10, "mm"), slice_normal=np.array((0., 0., 1.)),
                  slice_position=Quantity([0, 0, -3], "cm"), reference_time=Quantity(105, "ms"))

sliced_mesh = refined_mesh.select_slice(**slice_dict)
slice_mesh = cmrsim.datasets.MeshDataset(sliced_mesh, refined_mesh.timing)

# Probe fiber/sheet directions from the reference mesh onto the slice, then transform them to
# every time step. rotation_only=False presumably applies the full deformation, not just its
# rotation part — TODO confirm in cmrsim.
reference_mesh = pyvista.read("../example_resources/mesh_displacements/ED_reference_mesh.vtk")
slice_mesh.probe_reference(reference_mesh, ["fibers", "sheets"], reference_time=None)
slice_mesh.transform_vectors(slice_mesh.timing, ["fibers", "sheets"], reference_time=None, rotation_only=False)

Render animations

[3]:
# Render the slice's fiber (red) and sheet (blue) vectors over time into one GIF each.
for gif_name, vector_name, vector_color in (("vectors_slice_fiber", "fibers", "red"),
                                            ("vectors_slice_sheets", "sheets", "blue")):
    plotter = pyvista.Plotter(off_screen=True, window_size=(600, 600), theme=local_functions.get_custom_theme())
    try:
        slice_mesh.render_input_animation(gif_name, plotter=plotter,
                                          start=Quantity(0, "ms"), end=None, vector=vector_name,
                                          mesh_kwargs=dict(opacity=0., show_scalar_bar=False),
                                          vector_kwargs=dict(mag=5e-3, color=vector_color),
                                          text_kwargs={'position': 'upper_right', 'color': 'black', 'font_size': 16})
    finally:
        # Release the off-screen render window even if rendering fails.
        plotter.close()

# Place both animations side by side (fibers left, sheets right) in one GIF and embed it inline.
gif1 = imageio.v3.imread("vectors_slice_fiber.gif")
gif2 = imageio.v3.imread("vectors_slice_sheets.gif")
n_frames = min(len(gif1), len(gif2))  # zip() stops at the shorter GIF; keep tqdm's total in sync
with imageio.get_writer('meshmesh.gif') as combined_gif:  # finalizes the GIF even on error
    for f1, f2 in tqdm(zip(gif1, gif2), total=n_frames):
        combined_gif.append_data(np.concatenate([f1, f2], axis=1))
with open("meshmesh.gif", 'rb') as gif_file:
    b64 = base64.b64encode(gif_file.read()).decode('ascii')
display(HTML(f'<img src="data:image/gif;base64,{b64}" />'))

Store trajectories, fibers and sheets as intermediate results

[4]:
t, fibers_over_time = slice_mesh.get_field("fibers")
_, sheets_over_time = slice_mesh.get_field("sheets")
# NOTE(review): fibers/sheets are not normalized here. transform_vectors ran with
# rotation_only=False, so they may no longer be unit length or mutually orthogonal. Confirm
# whether DiffusionTaylor expects an orthonormal basis; normalizing would be
# `v = v / np.linalg.norm(v, axis=-1, keepdims=True)`.
ev3_over_time = np.cross(fibers_over_time, sheets_over_time)
# Concatenate the three basis vectors along the last axis and swap the first two axes.
# Assumes fields are (time, node, 3), which gives (node, time, 9) — TODO confirm.
local_basis = np.swapaxes(np.concatenate([fibers_over_time, sheets_over_time, ev3_over_time], axis=-1), 0, 1)

# Node positions over time = reference points + displacements, with the first two axes swapped.
particle_trajectories = Quantity(np.swapaxes(slice_mesh.get_field("displacements")[1] + slice_mesh.mesh.points, 0, 1), "m")
# The same three diffusivities (0.18, 0.12, 0.07 mm^2/s) for every node, presumably ordered like
# the basis (fiber, sheet, ev3).
diffusivities = Quantity(np.repeat([[1.8, 1.2, 0.7], ], particle_trajectories.shape[0], axis=0) * 0.1, "mm^2/s")
# Store plain magnitudes. The time axis is converted explicitly to ms because the loading cell
# re-attaches "ms"; `t.m` would store it in whatever unit `t` happens to carry.
np.savez("temp", local_basis=local_basis, particle_trajectories=particle_trajectories.m,
         diffusivities=diffusivities.m, time=t.m_as("ms"))

The DiffusionTaylor module#

Fit Taylor expansion for positions and local_basis#

[5]:
# Reload the intermediate results. The context manager closes the .npz file handle once the
# arrays have been read into memory.
with np.load("temp.npz") as fi:
    local_basis, particle_trajectories, diffusivities, time = [fi[k] for k in ('local_basis', 'particle_trajectories', 'diffusivities', 'time')]
particle_trajectories = Quantity(particle_trajectories, "m")
# Diffusivities are scaled up by a factor of 10 000, presumably to exaggerate the random walk so
# it is visible in the animations below — TODO confirm intent.
diffusivities = Quantity(diffusivities, "mm^2/s") * 10_000
t = Quantity(time, "ms")
[6]:
# Trajectory module: order-5 Taylor fit of node positions and local eigen-bases on the time grid
# `t`, with 100 random-walk particles per mesh node.
mod = cmrsim.trajectory.DiffusionTaylor(
    time_grid=t,
    particle_trajectories=particle_trajectories,
    local_basis_over_time=local_basis,
    diffusivities=diffusivities,
    order=5,
    particles_per_node=100,
)

Evaluate particle trajectories#

[7]:
# Run tf.functions eagerly for this evaluation.
tf.config.run_functions_eagerly(True)
# Evaluation times 0 .. 20 (exclusive) in steps of 0.05 -> 400 points, cast to float32.
# Presumably ms, matching the time grid `t`.
timing = np.arange(0., 20, 0.05).astype(np.float32)

# The second positional argument (0.01) is interpreted by DiffusionTaylor — presumably a time step; confirm.
# With return_eigenbasis=True, per-time-step dicts of fibers/sheets/ev3 are returned alongside positions.
res, fields = mod(timing, 0.01, return_eigenbasis=True)
(res.shape,
 fields[1].keys(),
 res[1, 0, 0, 0:3].shape,
 fields[1]["fibers"].shape)
100%|██████████████████████████████████████████████████████████████████████████████| 2315/2315 [00:24<00:00, 95.51it/s]
[7]:
((400, 100, 5756, 3),
 dict_keys(['fibers', 'sheets', 'ev3']),
 (3,),
 TensorShape([5756, 3]))

Plot trajectory projections as an interactive widget#

[ ]:
import ipywidgets as widgets

%matplotlib widget

plt.close("all")
f, axes = plt.subplots(1, 3, figsize=(14, 4))

skips = 30
point_artists =  [a.plot(res[0, :, ::skips, i].reshape(-1), res[0, :, ::skips, j].reshape(-1), ".", markersize=1, color="black")[0]
                    for a, (i, j) in zip(axes, [(0, 1), (0, 2), (1, 2)])]

quiver_artists = [a.quiver(res[0, 0, ::skips, i], res[0, 0, ::skips, j],
                           fields[0]["fibers"][::skips, i], fields[0]["fibers"][::skips, j],
                  scale=15, color="red") for a, (i, j) in zip(axes, [(0, 1), (0, 2), (1, 2)])]

quiver_artists2 = [a.quiver(res[0, 0, ::skips, i], res[0, 0, ::skips, j],
                           fields[0]["sheets"][::skips, i], fields[0]["sheets"][::skips, j],
                  scale=15, color="blue") for a, (i, j) in zip(axes, [(0, 1), (0, 2), (1, 2)])]

quiver_artists3 = [a.quiver(res[0, 0, ::skips, i], res[0, 0, ::skips, j],
                           fields[0]["ev3"][::skips, i], fields[0]["ev3"][::skips, j],
                  scale=15, color="green") for a, (i, j) in zip(axes, [(0, 1), (0, 2), (1, 2)])]

[a.set_title(s) for a, s in zip(axes, ["xy", "xz", "zy"])]
@widgets.interact(t=widgets.IntSlider(0, min=0, max=res.shape[0]-1, step=1))
def _upd(t):
    for artp, artq, artq2, artq3, (i, j) in zip(point_artists, quiver_artists, quiver_artists2, quiver_artists3, [(0, 1), (0, 2), (1, 2)]):
        artp.set_data(res[t, :, ::skips, i].reshape(-1), res[t, :, ::skips, j].reshape(-1))
        artq.set_offsets(res[t, 0, ::skips, (i,j)].T)
        artq.set_UVC(fields[t]["fibers"][::skips, i], fields[t]["fibers"][::skips, j])
        artq2.set_offsets(res[t, 0, ::skips, (i,j)].T)
        artq2.set_UVC(fields[t]["sheets"][::skips, i], fields[t]["sheets"][::skips, j])
        artq3.set_offsets(res[t, 0, ::skips, (i,j)].T)
        artq3.set_UVC(fields[t]["ev3"][::skips, i], fields[t]["ev3"][::skips, j])

# ani = FuncAnimation(f, _upd, frames=tqdm(range(0, 1500, 15)))
# ani.save("particle_trajectories.gif")

Render 3D animations of a single node and of multiple nodes#

[8]:
pyvista.close_all()
n_nodes = 100     # number of mesh nodes shown in the multi-node animation
frame_stride = 5  # animate every 5th evaluated time step
n_step = res[::frame_stride].shape[0]
# Subsample the eigen-basis fields with the same stride as positions and times, so vector frame k
# belongs to the same time step as position frame k. The unstrided `fields` would be out of step.
fields_strided = fields[::frame_stride]
eigenbasis_keys = ["fibers", "sheets", "ev3"]
eigenbasis_colors = ["red", "blue", "green"]

# Animation 1: the particles of node 0 together with its fiber/sheet/ev3 vectors.
plotter = pyvista.Plotter(off_screen=True, window_size=(500, 500))
# NOTE(review): adds a static green point cloud of node 0's particles at the final time step; `act` is unused.
act = plotter.add_mesh(res[-1, :, 0:1].reshape(-1, 3), render_points_as_spheres=True, color="green")
local_functions.animate_trajectories(plotter, res[::frame_stride, :, 0:1].reshape(n_step, -1, 3), timing[::frame_stride],
                                     filename="single_node.gif",
                                     mesh_kwargs=dict(render_points_as_spheres=True, color="white"),
                                     vectors=[[f[key][0:1].numpy() for f in fields_strided] for key in eigenbasis_keys],
                                     vector_pos=res[::frame_stride, 0, 0:1],
                                     multi_vecs=True,
                                     vector_kwargs=[dict(mag=0.5e-3, color=color) for color in eigenbasis_colors],
                                     text_kwargs=dict(font_size=10))

# Animation 2: the particles and eigen-vectors of the first `n_nodes` nodes.
pyvista.close_all()
plotter = pyvista.Plotter(off_screen=True, window_size=(500, 500))
local_functions.animate_trajectories(plotter, res[::frame_stride, :, 0:n_nodes].reshape(n_step, -1, 3), timing[::frame_stride],
                                     filename="multiple_nodes.gif",
                                     mesh_kwargs=dict(render_points_as_spheres=True, color="white"),
                                     vectors=[[f[key][0:n_nodes].numpy() for f in fields_strided] for key in eigenbasis_keys],
                                     vector_pos=res[::frame_stride, 0, 0:n_nodes],
                                     multi_vecs=True,
                                     vector_kwargs=[dict(mag=2.5e-3, color=color) for color in eigenbasis_colors],
                                     text_kwargs=dict(font_size=10)
                                    )
# Release the off-screen render windows and hide the rendering output.
pyvista.close_all()
clear_output()
[9]:
# Combine the single-node (left) and multi-node (right) animations into one side-by-side GIF and embed it inline.
gif1 = imageio.v3.imread("single_node.gif")
gif2 = imageio.v3.imread("multiple_nodes.gif")
n_frames = min(len(gif1), len(gif2))  # zip() stops at the shorter GIF; keep tqdm's total in sync
with imageio.get_writer('meshmesh.gif') as combined_gif:  # finalizes the GIF even on error
    for f1, f2 in tqdm(zip(gif1, gif2), total=n_frames):
        combined_gif.append_data(np.concatenate([f1, f2], axis=1))
with open("meshmesh.gif", 'rb') as gif_file:
    b64 = base64.b64encode(gif_file.read()).decode('ascii')
display(HTML(f'<img src="data:image/gif;base64,{b64}" />'))