Advice request: Lazily sampled slice/plane rendering #1059
-
|
Hello! I'm not sure whether this is best filed as a "discussion" or an "issue", so please feel free to move it. I am trying to make a UI similar to the slice rendering in the "Volume Slice" examples, but with the slices sampled lazily from a volume stored in a file on disk, so that it can be used with a larger-than-memory image. Additionally, I would like to be able to control the position of the plane to be rendered with the Transform Gizmo (super cool widget by the way!). See the movie below for a little demo.

plane_render.mov

The way I implemented this in my initial hacky prototype (code below) is to make a subclass of the `TransformGizmo` (`EventedTransformGizmo`) that calls a callback whenever it updates the object it controls. My questions are:
Thank you!

Hacky prototype script:
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "einops",
# "glfw",
# "numpy",
# "pygfx",
# "scipy",
# ]
# ///
"""Attempt at lazy slice rendering with an Image node."""
import einops
import numpy as np
import pygfx as gfx
from pylinalg import vec_transform
from rendercanvas.auto import RenderCanvas, loop
from scipy.ndimage import map_coordinates
# Size of the synthetic test volume (voxels per axis) and of the slice
# sampling grid, which covers one full face of that volume.
edge_length = 50
half_length = edge_length // 2
grid_shape = (edge_length,) * 2
grid_spacing = (1, 1)
class EventedTransformGizmo(gfx.TransformGizmo):
    """TransformGizmo that re-samples the controlled Image node after edits.

    Whenever the gizmo changes the object it controls (a pointer drag past the
    dead zone, or a pointer release), ``_on_object_update()`` is called. That
    callback transforms the module-level ``sampling_grid`` by the controlled
    object's world matrix, samples the module-level ``image`` volume at those
    points and writes the result into the controlled node's texture.
    """

    def process_event(self, event):
        """Handle pointer events; mostly copied from ``TransformGizmo``.

        The additions relative to the parent are the calls to
        ``self._on_object_update()`` in the "pointer_move" and "pointer_up"
        branches.
        """
        # No interaction if there is no object to control
        if not self._object_to_control:
            return
        # Triage over event type (named so it does not shadow builtin ``type``)
        event_type = event.type
        if event_type == "pointer_down":
            if event.button != 1 or event.modifiers:
                return
            self._ref = None
            # NOTE: I imagine that if multiple tools are active, they
            # each ask for picking info, causing multiple buffer reads
            # for the same location. However, with the new event system
            # this is probably not a problem, when wobjects receive events.
            ob = event.target
            if ob not in self.children:
                return
            # Depending on the object under the pointer, we scale/translate/rotate
            if ob == self._center_sphere:
                self._handle_start("scale", event, ob)
            elif ob in self._translate_children:
                self._handle_start("translate", event, ob)
            elif ob in self._scale_children:
                self._handle_start("scale", event, ob)
            elif ob in self._rotate_children:
                self._handle_start("rotate", event, ob)
            # Highlight the object
            self._highlight(ob)
            self._viewport.renderer.request_draw()
            self.set_pointer_capture(event.pointer_id, event.root)
        elif event_type == "pointer_up":
            if not self._ref:
                return
            if self._ref["dim"] is None and self._ref["maxdist"] < 3:
                self.toggle_mode()  # clicked on the center sphere
            self._ref = None
            # De-highlight the object
            self._highlight()
            # Re-sample the slice for the object's final transform
            self._on_object_update()
            self._viewport.renderer.request_draw()
        elif event_type == "pointer_move":
            if not self._ref:
                return
            # Get how far we've moved from starting point - we have a dead zone
            dist = (
                (event.x - self._ref["event_pos"][0]) ** 2
                + (event.y - self._ref["event_pos"][1]) ** 2
            ) ** 0.5
            self._ref["maxdist"] = max(self._ref["maxdist"], dist)
            # Delegate to the correct handler. Inside the dead zone no handler
            # runs, so the object is unchanged and re-sampling is skipped.
            if self._ref["maxdist"] >= 3:
                kind = self._ref["kind"]
                if kind == "translate":
                    self._handle_translate_move(event)
                elif kind == "scale":
                    self._handle_scale_move(event)
                elif kind == "rotate":
                    self._handle_rotate_move(event)
                # The object's transform may have changed: update the slice
                self._on_object_update()
            # Keep viz up to date
            self._viewport.renderer.request_draw()

    def _on_object_update(self):
        """Re-sample the volume onto the controlled Image node's texture.

        A quick-hack callback that could eventually become a real signal/slot
        system. It maps the module-level ``sampling_grid`` through the
        controlled object's world matrix and samples the module-level
        ``image`` volume at those points (``order=0``: nearest neighbour;
        points outside the volume read as ``cval=0``). The caller
        (``process_event``) requests the redraw.
        """
        print("Updating object")
        transform_matrix = self._object_to_control.world.matrix
        # The world matrix acts on XYZ points while the volume is indexed ZYX.
        # Reversing rows and columns of the linear part (P @ M @ P, with P the
        # axis-reversal permutation) and reversing the translation expresses
        # the same affine transform in ZYX order.
        translation_flipped = transform_matrix[:3, 3][::-1]
        print(translation_flipped)
        zyx_transform = np.eye(4, dtype=np.float32)
        zyx_transform[:3, :3] = transform_matrix[:3, :3][::-1, ::-1]
        zyx_transform[:3, 3] = translation_flipped
        # Flatten the sampling grid into a list of points and transform them
        grid_coords = sampling_grid.reshape(-1, 3)
        transformed_grid = vec_transform(grid_coords, zyx_transform)
        print(f" min_corner: {transformed_grid.min(axis=0)}")
        print(f" max_corner: {transformed_grid.max(axis=0)}")
        sampled_volume = map_coordinates(
            image,
            transformed_grid.reshape(-1, 3).T,
            order=0,
            cval=0,
        )
        print(f" labels: {np.unique(sampled_volume)}")
        # Write into the texture of the object this gizmo controls (rather than
        # a global), shaped like that texture's own data.
        texture = self._object_to_control.geometry.grid
        texture.data[:] = sampled_volume.reshape(texture.data.shape)
        texture.update_full()
def generate_3d_grid(
    grid_shape: tuple[int, int, int] = (10, 10, 10),
    grid_spacing: tuple[float, float, float] = (1, 1, 1),
    *,
    center: bool = False,
) -> np.ndarray:
    """
    Generate a 3D sampling grid with specified shape and spacing.

    Point ``grid[i, j, k]`` has coordinates ``(i, j, k) * grid_spacing``, so by
    default the grid starts at the origin (its first point is ``(0, 0, 0)``)
    rather than being centered on it. Pass ``center=True`` to shift the grid
    by ``grid_shape // 2`` index steps so the origin lies at its middle point
    (exactly for odd sizes, just past the middle for even sizes).

    Parameters
    ----------
    grid_shape : tuple[int, int, int]
        The number of grid points along each axis.
    grid_spacing : tuple[float, float, float]
        Spacing between neighboring points along each axis.
    center : bool, default False
        If True, subtract ``grid_shape // 2`` (in index units) before scaling.

    Returns
    -------
    np.ndarray
        Float array of shape ``(*grid_shape, 3)`` holding the coordinate of
        every grid point.
    """
    # np.indices stacks the per-axis index arrays first: (3, w, h, d).
    # Move that coordinate axis last to get one (i, j, k) triple per point.
    grid = np.moveaxis(np.indices(grid_shape).astype(float), 0, -1)
    if center:
        grid -= np.asarray(grid_shape) // 2
    # scale the grid to get correct spacing
    grid *= grid_spacing
    return grid
def generate_2d_grid(
    grid_shape: tuple[int, int] = (10, 10), grid_spacing: tuple[float, float] = (1, 1)
) -> np.ndarray:
    """
    Generate a 2D sampling grid with specified shape and spacing.

    The grid lies in the plane whose first coordinate is 0 (normal vector
    ``[1, 0, 0]``). Point ``grid[i, j]`` has coordinates
    ``(0, i * grid_spacing[0], j * grid_spacing[1])``, so the grid starts at
    the origin rather than being centered on it.

    Parameters
    ----------
    grid_shape : tuple[int, int]
        The number of grid points along each in-plane axis.
    grid_spacing : tuple[float, float]
        Spacing between neighboring points along each in-plane axis.

    Returns
    -------
    np.ndarray
        Float array of shape ``(w, h, 3)`` for ``grid_shape == (w, h)``.
    """
    grid = generate_3d_grid(
        grid_shape=(1, *grid_shape), grid_spacing=(1, *grid_spacing)
    )
    # Drop the singleton first axis: (1, w, h, 3) -> (w, h, 3).
    return grid[0]
def create_3d_cube_quadrants(size=8):
    """
    Create a 3D cube array in which every octant holds its own label.

    The array is indexed ZYX (``cube[z, y, x]``). Each octant is filled with
    an integer label from 1 to 8 (table below), and the whole array is then
    divided by 9, so the stored values are ``label / 9``. Each axis is split
    at ``mid = size // 2`` into a lower half ``[0, mid)`` and an upper half
    ``[mid, size)``.

          5-------6        top
         /|      /|
        7-------8 |
        | 1-----|-2        bottom
        |/      |/
        3-------4

    Octant labels, with (z, y, x) = (lower/upper half -> 0/1):
        1: Bottom Back Left   (0, 0, 0)
        2: Bottom Back Right  (0, 1, 0)
        3: Bottom Front Left  (0, 0, 1)
        4: Bottom Front Right (0, 1, 1)
        5: Top Back Left      (1, 0, 0)
        6: Top Back Right     (1, 1, 0)
        7: Top Front Left     (1, 0, 1)
        8: Top Front Right    (1, 1, 1)

    Parameters
    ----------
    size : int, optional (default=8)
        The edge length of each dimension of the cube.

    Returns
    -------
    np.ndarray
        ``(size, size, size)`` float32 array holding ``label / 9``.
    np.ndarray
        ``(8, 3)`` float32 array of octant centroids in ZYX order; row ``i``
        belongs to label ``i + 1``. Each coordinate is the midpoint of the
        half-range ``[0, mid)`` or ``[mid, size)``, which is also correct when
        ``size`` is odd and the halves differ in length.
    """
    # Create an empty cube
    cube = np.zeros((size, size, size), dtype=np.float32)
    # Split point of every axis
    mid = size // 2
    # Assign labels to each octant using direct indexing (ZYX order)
    cube[0:mid, 0:mid, 0:mid] = 1  # Bottom Back Left
    cube[0:mid, mid:, 0:mid] = 2  # Bottom Back Right
    cube[0:mid, 0:mid, mid:] = 3  # Bottom Front Left
    cube[0:mid, mid:, mid:] = 4  # Bottom Front Right
    cube[mid:, 0:mid, 0:mid] = 5  # Top Back Left
    cube[mid:, mid:, 0:mid] = 6  # Top Back Right
    cube[mid:, 0:mid, mid:] = 7  # Top Front Left
    cube[mid:, mid:, mid:] = 8  # Top Front Right
    # Scale labels 1..8 into (0, 1): value = label / 9.
    # NOTE(review): with a 9-entry "nearest" colormap over clim (0, 1), k/9 lies
    # exactly on the boundary between texels k-1 and k; (k + 0.5) / 9 would hit
    # texel centers. Verify the rendered colors before relying on this.
    cube /= 9
    # Midpoints of the lower and upper half of an axis
    lo = mid / 2
    hi = (mid + size) / 2
    centroids = np.array(
        [
            [lo, lo, lo],  # label 1: Bottom Back Left
            [lo, hi, lo],  # label 2: Bottom Back Right
            [lo, lo, hi],  # label 3: Bottom Front Left
            [lo, hi, hi],  # label 4: Bottom Front Right
            [hi, lo, lo],  # label 5: Top Back Left
            [hi, hi, lo],  # label 6: Top Back Right
            [hi, lo, hi],  # label 7: Top Front Left
            [hi, hi, hi],  # label 8: Top Front Right
        ],
        dtype=np.float32,
    )
    return cube, centroids
# make the image: a float32 volume indexed ZYX, plus per-octant centroids (ZYX)
image, coordinates = create_3d_cube_quadrants(size=edge_length)
# color map: entry k is the color intended for octant label k (entry 0 is the
# background); the printout at the end pairs labels with these names.
color_names = [
    "Black",
    "White",
    "Blue",
    "Magenta",
    "Green",
    "Cyan",
    "Orange",
    "Gray",
    "Yellow",
]
colors = np.array(
    [
        [0.0, 0.0, 0.0],  # Background: Black
        [1.0, 1.0, 1.0],  # Label 1: White
        [0.0, 0.0, 1.0],  # Label 2: Blue (standard blue)
        [1.0, 0.0, 1.0],  # Label 3: Magenta (pure magenta)
        [0.0, 1.0, 0.0],  # Label 4: Green (pure green)
        [0.0, 1.0, 1.0],  # Label 5: Cyan (pure cyan)
        [1.0, 0.5, 0.0],  # Label 6: Orange (standard orange)
        [0.5, 0.5, 0.5],  # Label 7: Gray (medium gray)
        [1.0, 1.0, 0.0],  # Label 8: Yellow (pure yellow)
    ],
    dtype=np.float32,
)
colormap = gfx.TextureMap(
    texture=gfx.Texture(colors, dim=1), filter="nearest", wrap="clamp"
)
# make the scene/canvas
canvas = RenderCanvas()
renderer = gfx.renderers.WgpuRenderer(canvas)
viewport = gfx.Viewport(renderer)
scene = gfx.Scene()
# add the axes
scene.add(gfx.AxesHelper(size=40, thickness=5))
# make the volume node (hidden until "q" is pressed)
geometry = gfx.Geometry(grid=image)
material = gfx.VolumeIsoMaterial(clim=(0, 1), threshold=0, map=colormap)
volume_node = gfx.Volume(geometry, material, visible=False)
scene.add(volume_node)
bounding_box = gfx.BoxHelper(color="red")
bounding_box.set_transform_by_object(volume_node)
scene.add(bounding_box)
# make the image node.
# Pass a COPY of the first slice: the texture's data is overwritten in place
# whenever the gizmo moves the plane. If the Texture wraps the array instead
# of copying it (NOTE(review): pygfx textures appear to do so - verify), a
# view into ``image`` would let those writes corrupt the z=0 slice of the very
# volume being sampled.
tex = gfx.Texture(image[0, ...].copy(), dim=2)
image_node = gfx.Image(
    gfx.Geometry(grid=tex),
    gfx.ImageBasicMaterial(clim=(0, 1), interpolation="nearest", map=colormap),
)
scene.add(image_node)
# make the points node marking each octant's centroid
points_geometry = gfx.Geometry(
    positions=coordinates[:, [2, 1, 0]],  # centroids are ZYX; reorder to XYZ
    colors=colors[1:],
)
points_material = gfx.PointsMaterial(
    size=5,
    color_mode="vertex",
)
points_node = gfx.Points(points_geometry, points_material)
scene.add(points_node)
# make the camera
camera = gfx.PerspectiveCamera(70, 16 / 9)
camera.show_object(scene, view_dir=(-1, -1, -1), up=(0, 0, 1))
controller = gfx.OrbitController(camera, register_events=renderer)
# add the gizmo that moves the image node (and triggers re-sampling)
gizmo = EventedTransformGizmo(image_node)
gizmo.add_default_event_handlers(viewport, camera)
# points of the slice plane, sampled through the image node's transform
sampling_grid = generate_2d_grid(
    grid_shape=grid_shape, grid_spacing=grid_spacing
)
def on_key_press(event):
    """Show or hide the full volume when the "q" key is pressed."""
    if event.key != "q":
        return
    volume_node.visible = not volume_node.visible
def animate():
    """Render the scene, overlay the gizmo on top, then present the frame."""
    for root in (scene, gizmo):
        viewport.render(root, camera)
    renderer.flush()
renderer.add_event_handler(on_key_press, "key_down")
print("Point coordinates:")
# Octant labels start at 1; color_names[0] is the background entry.
for label, (centroid, color_name) in enumerate(
    zip(coordinates, color_names[1:]), start=1
):
    print(f" Octant {label} coordinates: {centroid}, color: {color_name}")
if __name__ == "__main__":
    canvas.request_draw(animate)
    loop.run()
|
Beta Was this translation helpful? Give feedback.
Replies: 3 comments
-
|
Cool proof of concept!
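Reacting to events is still a thing we're ironing out. We will likely remove all events from the renderer (including "before_render"). The following should work though, and should keep working in the future: subclass `Image` and override `_update_object()` (called on every draw; don't forget to call `super()._update_object()`), then compare `local.last_modified` to a new variable `self._local_last_modified`, similar to how the world matrix is synced in `pygfx/objects/_base.py` (lines 196 to 201 at commit b6757cc). Then, when the local matrix has changed, update the slice.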
Hmm, I think the gizmo is "stuck" to the object that it tracks. One possibility is to move the geometry of the image so its origin is in the center of the image.
No, but I think it makes perfect sense to be able to do that. If you're up for it, you could create a PR and we'd help you implement it. |
Beta Was this translation helpful? Give feedback.
-
|
Thanks for the thorough reply @almarklein !
Okay, super good to know.
I think this works, thank you (updated prototype below)! From my understanding, it looks like I have to perform my check of `local.last_modified` before calling `super()._update_object()`.

Updated prototype:
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "einops",
# "glfw",
# "numpy",
# "pygfx",
# "scipy",
# ]
# ///
"""Attempt at lazy slice rendering with an Image node."""
import einops
import numpy as np
import pygfx as gfx
from pylinalg import vec_transform
from rendercanvas.auto import RenderCanvas, loop
from scipy.ndimage import map_coordinates
# Size of the synthetic test volume (voxels per axis) and of the slice
# sampling grid, which covers one full face of that volume.
edge_length = 50
half_length = edge_length // 2
grid_shape = (edge_length,) * 2
grid_spacing = (1, 1)
class LazyImageSlice(gfx.Image):
    """Image node whose texture is re-sampled from a volume when it moves.

    ``_update_object()`` is called on every draw. Here it checks whether this
    node's local transform changed since the last re-slice and, if so, maps
    the module-level ``sampling_grid`` through the node's world matrix and
    samples the module-level ``image`` volume at those points.

    NOTE(review): only changes to this node's own (local) transform trigger a
    re-slice; moving a parent node does not.
    """

    # ``local.last_modified`` value seen at the last re-slice. It is None until
    # the first draw, so the (initially empty) texture is filled on frame one.
    # Tracked here rather than via the parent's private ``_world_last_modified``,
    # which is an implementation detail of the base class.
    _sliced_at = None

    def _update_object(self):
        local_last_modified = self.local.last_modified
        if self._sliced_at is None or local_last_modified > self._sliced_at:
            self._sliced_at = local_last_modified
            self._reslice()
        super()._update_object()

    def _reslice(self):
        """Sample the volume on this node's plane and upload it to the texture."""
        print("Updating object")
        transform_matrix = self.world.matrix
        # The world matrix acts on XYZ points while the volume is indexed ZYX.
        # Reversing rows and columns of the linear part (P @ M @ P, with P the
        # axis-reversal permutation) and reversing the translation expresses
        # the same affine transform in ZYX order.
        translation_flipped = transform_matrix[:3, 3][::-1]
        print(translation_flipped)
        zyx_transform = np.eye(4, dtype=np.float32)
        zyx_transform[:3, :3] = transform_matrix[:3, :3][::-1, ::-1]
        zyx_transform[:3, 3] = translation_flipped
        # Flatten the sampling grid into a list of points and transform them
        grid_coords = sampling_grid.reshape(-1, 3)
        transformed_grid = vec_transform(grid_coords, zyx_transform)
        print(f" min_corner: {transformed_grid.min(axis=0)}")
        print(f" max_corner: {transformed_grid.max(axis=0)}")
        # order=0: nearest neighbour; points outside the volume read as cval=0
        sampled_volume = map_coordinates(
            image,
            transformed_grid.reshape(-1, 3).T,
            order=0,
            cval=0,
        )
        print(f" labels: {np.unique(sampled_volume)}")
        # Shape the samples like this node's own texture data
        texture = self.geometry.grid
        texture.data[:] = sampled_volume.reshape(texture.data.shape)
        texture.update_full()
def generate_3d_grid(
    grid_shape: tuple[int, int, int] = (10, 10, 10),
    grid_spacing: tuple[float, float, float] = (1, 1, 1),
    *,
    center: bool = False,
) -> np.ndarray:
    """
    Generate a 3D sampling grid with specified shape and spacing.

    Point ``grid[i, j, k]`` has coordinates ``(i, j, k) * grid_spacing``, so by
    default the grid starts at the origin (its first point is ``(0, 0, 0)``)
    rather than being centered on it. Pass ``center=True`` to shift the grid
    by ``grid_shape // 2`` index steps so the origin lies at its middle point
    (exactly for odd sizes, just past the middle for even sizes).

    Parameters
    ----------
    grid_shape : tuple[int, int, int]
        The number of grid points along each axis.
    grid_spacing : tuple[float, float, float]
        Spacing between neighboring points along each axis.
    center : bool, default False
        If True, subtract ``grid_shape // 2`` (in index units) before scaling.

    Returns
    -------
    np.ndarray
        Float array of shape ``(*grid_shape, 3)`` holding the coordinate of
        every grid point.
    """
    # np.indices stacks the per-axis index arrays first: (3, w, h, d).
    # Move that coordinate axis last to get one (i, j, k) triple per point.
    grid = np.moveaxis(np.indices(grid_shape).astype(float), 0, -1)
    if center:
        grid -= np.asarray(grid_shape) // 2
    # scale the grid to get correct spacing
    grid *= grid_spacing
    return grid
def generate_2d_grid(
    grid_shape: tuple[int, int] = (10, 10), grid_spacing: tuple[float, float] = (1, 1)
) -> np.ndarray:
    """
    Generate a 2D sampling grid with specified shape and spacing.

    The grid lies in the plane whose first coordinate is 0 (normal vector
    ``[1, 0, 0]``). Point ``grid[i, j]`` has coordinates
    ``(0, i * grid_spacing[0], j * grid_spacing[1])``, so the grid starts at
    the origin rather than being centered on it.

    Parameters
    ----------
    grid_shape : tuple[int, int]
        The number of grid points along each in-plane axis.
    grid_spacing : tuple[float, float]
        Spacing between neighboring points along each in-plane axis.

    Returns
    -------
    np.ndarray
        Float array of shape ``(w, h, 3)`` for ``grid_shape == (w, h)``.
    """
    grid = generate_3d_grid(
        grid_shape=(1, *grid_shape), grid_spacing=(1, *grid_spacing)
    )
    # Drop the singleton first axis: (1, w, h, 3) -> (w, h, 3).
    return grid[0]
def create_3d_cube_quadrants(size=8):
    """
    Create a 3D cube array in which every octant holds its own label.

    The array is indexed ZYX (``cube[z, y, x]``). Each octant is filled with
    an integer label from 1 to 8 (table below), and the whole array is then
    divided by 9, so the stored values are ``label / 9``. Each axis is split
    at ``mid = size // 2`` into a lower half ``[0, mid)`` and an upper half
    ``[mid, size)``.

          5-------6        top
         /|      /|
        7-------8 |
        | 1-----|-2        bottom
        |/      |/
        3-------4

    Octant labels, with (z, y, x) = (lower/upper half -> 0/1):
        1: Bottom Back Left   (0, 0, 0)
        2: Bottom Back Right  (0, 1, 0)
        3: Bottom Front Left  (0, 0, 1)
        4: Bottom Front Right (0, 1, 1)
        5: Top Back Left      (1, 0, 0)
        6: Top Back Right     (1, 1, 0)
        7: Top Front Left     (1, 0, 1)
        8: Top Front Right    (1, 1, 1)

    Parameters
    ----------
    size : int, optional (default=8)
        The edge length of each dimension of the cube.

    Returns
    -------
    np.ndarray
        ``(size, size, size)`` float32 array holding ``label / 9``.
    np.ndarray
        ``(8, 3)`` float32 array of octant centroids in ZYX order; row ``i``
        belongs to label ``i + 1``. Each coordinate is the midpoint of the
        half-range ``[0, mid)`` or ``[mid, size)``, which is also correct when
        ``size`` is odd and the halves differ in length.
    """
    # Create an empty cube
    cube = np.zeros((size, size, size), dtype=np.float32)
    # Split point of every axis
    mid = size // 2
    # Assign labels to each octant using direct indexing (ZYX order)
    cube[0:mid, 0:mid, 0:mid] = 1  # Bottom Back Left
    cube[0:mid, mid:, 0:mid] = 2  # Bottom Back Right
    cube[0:mid, 0:mid, mid:] = 3  # Bottom Front Left
    cube[0:mid, mid:, mid:] = 4  # Bottom Front Right
    cube[mid:, 0:mid, 0:mid] = 5  # Top Back Left
    cube[mid:, mid:, 0:mid] = 6  # Top Back Right
    cube[mid:, 0:mid, mid:] = 7  # Top Front Left
    cube[mid:, mid:, mid:] = 8  # Top Front Right
    # Scale labels 1..8 into (0, 1): value = label / 9.
    # NOTE(review): with a 9-entry "nearest" colormap over clim (0, 1), k/9 lies
    # exactly on the boundary between texels k-1 and k; (k + 0.5) / 9 would hit
    # texel centers. Verify the rendered colors before relying on this.
    cube /= 9
    # Midpoints of the lower and upper half of an axis
    lo = mid / 2
    hi = (mid + size) / 2
    centroids = np.array(
        [
            [lo, lo, lo],  # label 1: Bottom Back Left
            [lo, hi, lo],  # label 2: Bottom Back Right
            [lo, lo, hi],  # label 3: Bottom Front Left
            [lo, hi, hi],  # label 4: Bottom Front Right
            [hi, lo, lo],  # label 5: Top Back Left
            [hi, hi, lo],  # label 6: Top Back Right
            [hi, lo, hi],  # label 7: Top Front Left
            [hi, hi, hi],  # label 8: Top Front Right
        ],
        dtype=np.float32,
    )
    return cube, centroids
# make the image: a float32 volume indexed ZYX, plus per-octant centroids (ZYX)
image, coordinates = create_3d_cube_quadrants(size=edge_length)
# color map: entry k is the color intended for octant label k (entry 0 is the
# background); the printout at the end pairs labels with these names.
color_names = [
    "Black",
    "White",
    "Blue",
    "Magenta",
    "Green",
    "Cyan",
    "Orange",
    "Gray",
    "Yellow",
]
colors = np.array(
    [
        [0.0, 0.0, 0.0],  # Background: Black
        [1.0, 1.0, 1.0],  # Label 1: White
        [0.0, 0.0, 1.0],  # Label 2: Blue (standard blue)
        [1.0, 0.0, 1.0],  # Label 3: Magenta (pure magenta)
        [0.0, 1.0, 0.0],  # Label 4: Green (pure green)
        [0.0, 1.0, 1.0],  # Label 5: Cyan (pure cyan)
        [1.0, 0.5, 0.0],  # Label 6: Orange (standard orange)
        [0.5, 0.5, 0.5],  # Label 7: Gray (medium gray)
        [1.0, 1.0, 0.0],  # Label 8: Yellow (pure yellow)
    ],
    dtype=np.float32,
)
colormap = gfx.TextureMap(
    texture=gfx.Texture(colors, dim=1), filter="nearest", wrap="clamp"
)
# make the scene/canvas
canvas = RenderCanvas()
renderer = gfx.renderers.WgpuRenderer(canvas)
viewport = gfx.Viewport(renderer)
scene = gfx.Scene()
# add the axes
scene.add(gfx.AxesHelper(size=40, thickness=5))
# make the volume node (hidden until "q" is pressed)
geometry = gfx.Geometry(grid=image)
material = gfx.VolumeIsoMaterial(clim=(0, 1), threshold=0, map=colormap)
volume_node = gfx.Volume(geometry, material, visible=False)
scene.add(volume_node)
bounding_box = gfx.BoxHelper(color="red")
bounding_box.set_transform_by_object(volume_node)
scene.add(bounding_box)
# make the image node. Its texture receives one sample per sampling-grid
# point, so size it from grid_shape (not from a slice of the volume, which
# would break as soon as the sampling window differs from the volume's YX
# extent). It starts empty and is filled on the first draw.
tex = gfx.Texture(np.zeros(grid_shape, dtype=image.dtype), dim=2)
image_node = LazyImageSlice(
    gfx.Geometry(grid=tex),
    gfx.ImageBasicMaterial(clim=(0, 1), interpolation="nearest", map=colormap),
)
scene.add(image_node)
# make the points node marking each octant's centroid
points_geometry = gfx.Geometry(
    positions=coordinates[:, [2, 1, 0]],  # centroids are ZYX; reorder to XYZ
    colors=colors[1:],
)
points_material = gfx.PointsMaterial(
    size=5,
    color_mode="vertex",
)
points_node = gfx.Points(points_geometry, points_material)
scene.add(points_node)
# make the camera
camera = gfx.PerspectiveCamera(70, 16 / 9)
camera.show_object(scene, view_dir=(-1, -1, -1), up=(0, 0, 1))
controller = gfx.OrbitController(camera, register_events=renderer)
# add the gizmo that moves the image node
gizmo = gfx.TransformGizmo(image_node)
gizmo.add_default_event_handlers(viewport, camera)
# points of the slice plane, sampled through the image node's transform
sampling_grid = generate_2d_grid(
    grid_shape=grid_shape, grid_spacing=grid_spacing
)
def on_key_press(event):
    """Show or hide the full volume when the "q" key is pressed."""
    if event.key != "q":
        return
    volume_node.visible = not volume_node.visible
def animate():
    """Render the scene, overlay the gizmo on top, then present the frame."""
    for root in (scene, gizmo):
        viewport.render(root, camera)
    renderer.flush()
renderer.add_event_handler(on_key_press, "key_down")
print("Point coordinates:")
# Octant labels start at 1; color_names[0] is the background entry.
for label, (centroid, color_name) in enumerate(
    zip(coordinates, color_names[1:]), start=1
):
    print(f" Octant {label} coordinates: {centroid}, color: {color_name}")
if __name__ == "__main__":
    canvas.request_draw(animate)
    loop.run()
I am definitely interested - not sure I have time though. I'll see if I can free up some time in the coming weeks to give it a go. Thanks again! |
Beta Was this translation helpful? Give feedback.
-
yes that makes sense! |
Beta Was this translation helpful? Give feedback.
Cool proof of concept!
Reacting to events is still a thing we're ironing out. We will likely remove all events from the renderer (including "before_render").
The following should work though, and should keep working in the future:
The `WorldObject` class has `_update_object()`, which gets called on every draw. You could subclass `Image` and implement that (don't forget to call `super()._update_object()`). Then compare `local.last_modified` to a new variable `self._local_last_modified`, similar to what we do here to sync the world matrix: `pygfx/pygfx/objects/_base.py`, lines 196 to 201 in b6757cc.