Plot with Mayavi
We recommend running this notebook in a separate Python environment dedicated to plotting with Mayavi; the required dependencies are listed in README.md.
[1]:
import numpy as np
from mayavi import mlab
import scanpy as sc
import os
def _ensure_3d(X, z_plane):
    """Return (x, y, z) 1-D arrays placing every point of X on the plane z = z_plane.

    X must be 2-D with 2 or 3 columns. A third column, if present, is ignored so
    that each slice is drawn flat at its own z level.
    """
    X = np.asarray(X)
    if X.ndim != 2 or X.shape[1] not in (2, 3):
        raise ValueError("X must have 2 or 3 columns.")
    x, y = X[:, 0], X[:, 1]
    z = np.full(X.shape[0], z_plane, dtype=float)
    return x, y, z


def _resolve_label_colors(all_labels, label_colors):
    """Return a NEW {label: (r, g, b)} dict that has a color for every label.

    User-supplied colors are kept as given. Labels without a color get a
    pseudo-random color with each channel in [0.3, 1.0) (to avoid very dark
    colors), drawn from a generator seeded with 42 so the result is reproducible.
    The caller's dict is never modified.
    """
    colors = {} if label_colors is None else dict(label_colors)
    missing = [lab for lab in all_labels if lab not in colors]
    if missing:
        rng = np.random.default_rng(42)
        for lab in missing:
            colors[lab] = tuple(rng.random(3) * 0.7 + 0.3)
    return colors


def _select_ot_edges(P, top_k, prob_threshold, max_edges, seed):
    """Choose which (i, j, weight) entries of the OT matrix P to draw as edges.

    For every row i, keep up to `top_k` of its largest entries (in descending
    order of weight) whose weight is both >= prob_threshold and > 0. If more
    than `max_edges` edges survive (and max_edges is not None), keep a uniform
    random subset of size max_edges drawn without replacement with `seed`.
    A non-positive top_k selects no edges.
    """
    edges = []
    if top_k <= 0:
        return edges
    n_targets = P.shape[1]
    for i, row in enumerate(P):
        if top_k >= n_targets:
            top_idx = np.argsort(row)[::-1]  # every column, descending
        else:
            # argpartition finds the top_k entries in O(n); only those few are then sorted.
            top_idx = np.argpartition(row, -top_k)[-top_k:]
            top_idx = top_idx[np.argsort(row[top_idx])[::-1]]
        w_top = row[top_idx]
        keep = (w_top >= prob_threshold) & (w_top > 0)
        edges.extend((i, j, w) for j, w in zip(top_idx[keep], w_top[keep]))
    if max_edges is not None and len(edges) > max_edges:
        rng = np.random.default_rng(seed)
        idx = rng.choice(len(edges), size=max_edges, replace=False)
        edges = [edges[k] for k in idx]
    return edges


def _plot_slice_points(x, y, z, labels, colors, point_size):
    """Draw one sphere glyph set per distinct label of a slice, colored by `colors`."""
    for lab in np.unique(labels):
        mask = labels == lab
        mlab.points3d(
            x[mask], y[mask], z[mask],
            scale_factor=point_size,
            color=colors[lab],
            mode="sphere",
            resolution=16,
            opacity=1.0,
        )


def visualize_ot_mapping_3d(
    X1, X2, P,
    labels1, labels2,
    label_colors=None,
    z1=0.0, z2=1.0,
    top_k=1,
    prob_threshold=0.0,
    max_edges=800,
    edge_opacity=0.15,
    edge_color=(0.2, 0.2, 0.2),
    edge_thickness=None,
    point_size=4.0,
    bg_color=(1, 1, 1),
    figure_size=(1000, 800),
    seed=0,
    save=False,
    savename=None,
):
    """
    Visualize an OT mapping between two slices using Mayavi in 3D.

    The two slices are drawn as flat point clouds at z = z1 and z = z2, colored
    by label. The strongest entries of the OT matrix are drawn as straight
    lines between them.

    Parameters
    ----------
    X1, X2 : array-like
        Coordinates of points in slice 1 and slice 2.
        Shape: (n1, 2) or (n1, 3), (n2, 2) or (n2, 3). A third column is ignored;
        each slice is flattened onto its own z plane.
    P : array-like
        OT matrix of shape (n1, n2).
    labels1, labels2 : array-like
        Cell-type labels (e.g. strings or ints) for X1 and X2; lengths n1 and n2.
    label_colors : dict or None
        Optional mapping {label: (r,g,b)} with values in [0,1]. Labels missing
        from it (or all labels, if None) get automatically assigned colors.
        The dict passed in is not modified.
    z1, z2 : float
        z coordinates where the two slices will be placed.
    top_k : int
        For each source point in X1, keep at most top_k target connections.
        Values <= 0 draw no edges.
    prob_threshold : float
        Ignore edges with P[i,j] < prob_threshold (zero entries are always ignored).
    max_edges : int or None
        Maximum number of edges to draw in total. If exceeded, a random subset
        is drawn. None disables the cap.
    edge_opacity : float
        Opacity of edges (0-1).
    edge_color : tuple
        RGB color of edges, values in [0,1].
    edge_thickness : float or None
        Passed to ``mlab.plot3d`` as ``tube_radius``. None draws plain lines
        instead of tubes.
    point_size : float
        ``scale_factor`` of the sphere glyphs.
    bg_color : tuple
        Background color of the figure.
    figure_size : tuple
        Size of the Mayavi figure (width, height).
    seed : int
        Random seed for edge subsampling.
    save : bool
        If True, write the scene to ``savename`` instead of opening an
        interactive window.
    savename : str or None
        Output path; required when ``save`` is True.

    Raises
    ------
    ValueError
        If ``save`` is True but no ``savename`` is given, if P is not 2-D, or if
        X1/X2 do not have 2 or 3 columns.
    AssertionError
        If the number of points or labels does not match the shape of P.
    """
    # Fail before building a (possibly slow) scene that could never be saved.
    if save and not savename:
        raise ValueError("savename must be provided when save=True.")

    X1 = np.asarray(X1)
    X2 = np.asarray(X2)
    P = np.asarray(P)
    n1, n2 = P.shape
    assert X1.shape[0] == n1, "X1 must have one row per row of P"
    assert X2.shape[0] == n2, "X2 must have one row per column of P"

    x1, y1, z1_arr = _ensure_3d(X1, z1)
    x2, y2, z2_arr = _ensure_3d(X2, z2)

    labels1 = np.asarray(labels1)
    labels2 = np.asarray(labels2)
    assert labels1.shape[0] == n1, "labels1 must have one entry per point of X1"
    assert labels2.shape[0] == n2, "labels2 must have one entry per point of X2"
    all_labels = np.unique(np.concatenate([labels1, labels2]))
    colors = _resolve_label_colors(all_labels, label_colors)

    mlab.figure(bgcolor=bg_color, size=figure_size)
    _plot_slice_points(x1, y1, z1_arr, labels1, colors, point_size)
    _plot_slice_points(x2, y2, z2_arr, labels2, colors, point_size)

    # Each selected OT edge is a separate thin, semi-transparent segment.
    for i, j, _weight in _select_ot_edges(P, top_k, prob_threshold, max_edges, seed):
        mlab.plot3d(
            [x1[i], x2[j]],
            [y1[i], y2[j]],
            [z1_arr[i], z2_arr[j]],
            tube_radius=edge_thickness,
            tube_sides=32,
            line_width=1.0,
            color=edge_color,
            opacity=edge_opacity,
        )

    mlab.orientation_axes()
    mlab.view(azimuth=60, elevation=75, distance="auto")
    # NOTE(review): creates an axes object and immediately hides its parent module.
    # Presumably this keeps only the orientation widget visible; behavior kept as-is.
    mlab.axes().parent.visible = False

    if save:
        fig = mlab.gcf()
        fig.scene.anti_aliasing_frames = 16  # or 4 if it gets too slow
        mlab.savefig(savename, size=(3000, 2400))
    else:
        mlab.show()
objc[1781]: Class RunLoopModeTracker is implemented in both /opt/anaconda3/envs/mayavi_env/lib/libQt5Core.5.15.15.dylib (0x10ed3d370) and /opt/anaconda3/envs/mayavi_env/lib/libQt6Core.6.9.3.dylib (0x3174c55d0). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.
objc[1781]: Class QT_ROOT_LEVEL_POOL__THESE_OBJECTS_WILL_BE_RELEASED_WHEN_QAPP_GOES_OUT_OF_SCOPE is implemented in both /opt/anaconda3/envs/mayavi_env/lib/libQt5Core.5.15.15.dylib (0x10ed3d2f8) and /opt/anaconda3/envs/mayavi_env/lib/libQt6Core.6.9.3.dylib (0x3174c5698). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.
objc[1781]: Class KeyValueObserver is implemented in both /opt/anaconda3/envs/mayavi_env/lib/libQt5Core.5.15.15.dylib (0x10ed3d320) and /opt/anaconda3/envs/mayavi_env/lib/libQt6Core.6.9.3.dylib (0x3174c56c0). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.
[2]:
# Input locations, relative to the notebook's working directory:
#   adata_datadir  -> per-sample folders holding <id>/<id>_commot_distance.h5ad
#   raftup_datadir -> saved matching matrices (full_matching_*.npy)
adata_datadir, raftup_datadir = "./data/", "./results/align_data/"
[3]:
def hex_to_rgb01(hex_str):
    """Convert a hex color string to an (r, g, b) tuple of floats in [0, 1].

    Accepted forms (leading '#' and surrounding whitespace optional):
      - '#RRGGBB'   e.g. '#FBB4AE'
      - '#RGB'      CSS shorthand, each digit doubled ('#F00' == '#FF0000')
      - '#RRGGBBAA' the alpha byte is ignored

    Raises
    ------
    ValueError
        If the string has any other number of digits or contains non-hex characters.
    """
    digits = hex_str.strip().lstrip("#")
    if len(digits) == 3:
        digits = "".join(ch * 2 for ch in digits)
    # Reject lengths that would otherwise be mis-parsed (e.g. '#12345' -> blue = 0x5).
    if len(digits) not in (6, 8):
        raise ValueError(
            f"Expected a color like '#RGB', '#RRGGBB' or '#RRGGBBAA', got {hex_str!r}"
        )
    return tuple(int(digits[k:k + 2], 16) / 255.0 for k in (0, 2, 4))
# Palette keyed by the integer `original_clusters` labels (0-6) that is passed as
# `label_colors` when plotting. The hex values appear to be the first seven
# ColorBrewer "Pastel1" colors. `layer_color` holds the same palette as
# (r, g, b) floats in [0, 1], the form Mayavi expects.
layer_color_hex = dict(enumerate([
    "#FBB4AE",
    "#B3CDE3",
    "#CCEBC5",
    "#DECBE4",
    "#FED9A6",
    "#FFFFCC",
    "#E5D8BD",
]))
layer_color = dict(zip(layer_color_hex, map(hex_to_rgb01, layer_color_hex.values())))
[6]:
# Output directory for rendered figures (created if missing).
figure_dir = "./results/figures"
os.makedirs(figure_dir, exist_ok=True)

# IDs of the two slices whose mapping is plotted.
d_1 = "151508"
d_2 = "151509"
# Run tag embedded in the saved matching filename.
# TODO(review): document what "206" and "0.4" denote for this run.
raftup_run_tag = "206_0.4"
print(d_1, d_2)


def load_slice(sample_id):
    """Load one slice: its AnnData, spot coordinates (obsm['spatial']) and integer cluster labels."""
    adata = sc.read_h5ad(
        os.path.join(adata_datadir, sample_id, f"{sample_id}_commot_distance.h5ad")
    )
    coords = adata.obsm["spatial"]
    labels = np.array(adata.obs["original_clusters"], dtype=int)
    return adata, coords, labels


adata_1, X1, l1 = load_slice(d_1)
adata_2, X2, l2 = load_slice(d_2)
P_raftup = np.load(
    os.path.join(raftup_datadir, f"full_matching_{d_1}_{d_2}_{raftup_run_tag}.npy")
)
save_path = os.path.join(figure_dir, f"{d_1}_{d_2}_full_global.png")

visualize_ot_mapping_3d(
    X1, X2, P_raftup,
    l1, l2,
    z1=0.0,
    z2=3500.0,  # vertical gap between the two slices, in the coordinates' units
    top_k=1,
    prob_threshold=1e-5,
    max_edges=1000,
    edge_opacity=0.5,
    edge_color=(0.3, 0.3, 0.3),
    label_colors=layer_color,
    point_size=75.0,
    save=True,
    savename=save_path,
)
print(f"Saved figure to {save_path}")
151508 151509
Saved figure to ./results/figures/151508_151509_full_global.png