Plot with Mayavi

We recommend using a separate plotting environment for Mayavi. The required dependencies can be found in README.md.

[1]:
import numpy as np
from mayavi import mlab
import scanpy as sc
import os

def _to_plane_coords(X, z_plane):
    """Return (x, y, z) arrays for ``X`` with every point placed at ``z = z_plane``."""
    X = np.asarray(X)
    if X.ndim != 2 or X.shape[1] not in (2, 3):
        raise ValueError("X must have 2 or 3 columns.")
    x, y = X[:, 0], X[:, 1]
    # A third input column, if present, is deliberately ignored: each slice is
    # drawn as a flat sheet at its own z plane.
    z = np.full_like(x, z_plane, dtype=float)
    return x, y, z


def _select_ot_edges(P, top_k, prob_threshold, max_edges, seed):
    """Pick the (source, target, weight) edges of the OT matrix ``P`` to draw."""
    n1, n2 = P.shape
    # Clamp so that top_k <= 0 selects nothing. Without the clamp,
    # ``row[-0:]`` would silently select *every* target.
    k = min(int(top_k), n2)

    edges = []
    if k > 0:
        for i in range(n1):
            row = P[i]
            if k == n2:
                top_idx = np.argsort(row)[::-1]  # all targets, descending weight
            else:
                # argpartition finds the k largest without a full sort;
                # only those k are then ordered by descending weight.
                top_idx = np.argpartition(row, -k)[-k:]
                top_idx = top_idx[np.argsort(row[top_idx])[::-1]]
            for j in top_idx:
                w = row[j]
                if w >= prob_threshold and w > 0:
                    edges.append((i, j, w))

    # If still too many, subsample globally (reproducible through ``seed``).
    if max_edges is not None and len(edges) > max_edges:
        rng = np.random.default_rng(seed)
        keep = rng.choice(len(edges), size=max_edges, replace=False)
        edges = [edges[e] for e in keep]
    return edges


def visualize_ot_mapping_3d(
    X1, X2, P,
    labels1, labels2,
    label_colors=None,
    z1=0.0, z2=1.0,
    top_k=1,
    prob_threshold=0.0,
    max_edges=800,
    edge_opacity=0.15,
    edge_color=(0.2, 0.2, 0.2),
    edge_thickness = None,
    point_size=4.0,
    bg_color=(1, 1, 1),
    figure_size=(1000, 800),
    seed=0,
    save=False,
    savename = None,
):
    """
    Visualize OT mapping between two slices using Mayavi in 3D.

    Each slice is drawn as a flat sheet of label-colored spheres at its own
    z plane, and a subset of the OT couplings is drawn as lines between them.

    Parameters
    ----------
    X1, X2 : array-like
        Coordinates of points in slice 1 and slice 2.
        Shape: (n1, 2) or (n1, 3), (n2, 2) or (n2, 3). Only the first two
        columns are used; a third column is ignored and replaced by z1 / z2.
    P : array-like
        OT matrix of shape (n1, n2).
    labels1, labels2 : array-like
        Cell-type labels (e.g. strings or ints) for X1 and X2, of length
        n1 and n2 respectively.
    label_colors : dict or None
        Optional mapping {label: (r,g,b)} with values in [0,1].
        Labels without an entry (all labels, if None) get a color drawn from
        a fixed-seed random generator. The mapping passed in is not modified.
    z1, z2 : float
        z coordinates where the two slices will be placed.
    top_k : int
        For each source point in X1, keep at most top_k target connections.
        Values <= 0 draw no edges.
    prob_threshold : float
        Ignore edges with P[i,j] < prob_threshold. Edges with P[i,j] <= 0
        are always ignored.
    max_edges : int or None
        Maximum number of edges to draw in total. If exceeded, subsample
        uniformly at random. None disables the cap.
    edge_opacity : float
        Opacity of edges (0–1).
    edge_color : tuple
        RGB color of edges, values in [0,1].
    edge_thickness : float or None
        Passed to ``mlab.plot3d`` as ``tube_radius``. None draws plain lines
        instead of tubes.
    point_size : float
        Size of points (passed to ``mlab.points3d`` as ``scale_factor``).
    bg_color : tuple
        Background color of the figure.
    figure_size : tuple
        Size of the Mayavi figure (width, height).
    seed : int
        Random seed for edge subsampling.
    save : bool
        If True, write the rendered scene to ``savename`` (3000x2400, with
        16 anti-aliasing frames) instead of opening the interactive window.
    savename : str or None
        Output file path; required when ``save`` is True.

    Raises
    ------
    ValueError
        If ``save`` is True without a ``savename``, if a coordinate array does
        not have 2 or 3 columns, or if a label array does not match its slice.
    AssertionError
        If X1 / X2 do not have one row per row / column of P.
    """
    # Fail before any rendering work instead of inside mlab.savefig(None).
    if save and savename is None:
        raise ValueError("savename must be provided when save=True.")

    X1 = np.asarray(X1)
    X2 = np.asarray(X2)
    P  = np.asarray(P)

    n1, n2 = P.shape
    assert X1.shape[0] == n1, "X1 must have one row per row of P."
    assert X2.shape[0] == n2, "X2 must have one row per column of P."

    x1, y1, z1_arr = _to_plane_coords(X1, z1)
    x2, y2, z2_arr = _to_plane_coords(X2, z2)

    labels1 = np.asarray(labels1)
    labels2 = np.asarray(labels2)
    if labels1.shape != (n1,) or labels2.shape != (n2,):
        raise ValueError(
            "labels1 and labels2 must be 1-D with one label per point in X1 and X2."
        )

    # ----- Build a color map for labels -----
    # Work on a copy so the caller's dict is never mutated.
    colors = {} if label_colors is None else dict(label_colors)
    fallback_rng = np.random.default_rng(42)
    for lab in np.unique(np.concatenate([labels1, labels2])):
        if lab not in colors:
            # 0.3 + 0.7 * U(0, 1) keeps auto-assigned colors from being too dark.
            colors[lab] = tuple(fallback_rng.random(3) * 0.7 + 0.3)

    # ----- Create figure -----
    mlab.figure(bgcolor=bg_color, size=figure_size)

    # ----- Plot points for each slice, colored by label -----
    slices = ((x1, y1, z1_arr, labels1), (x2, y2, z2_arr, labels2))
    for xs, ys, zs, labs in slices:
        for lab in np.unique(labs):
            mask = (labs == lab)
            # Can only be empty for labels that do not compare equal to
            # themselves (e.g. NaN); skip those rather than plot nothing.
            if not np.any(mask):
                continue
            mlab.points3d(
                xs[mask], ys[mask], zs[mask],
                scale_factor=point_size,
                color=colors[lab],
                mode="sphere",
                resolution=16,
                opacity=1.0,
            )

    # ----- Select a subset of OT edges to avoid clutter -----
    edges = _select_ot_edges(P, top_k, prob_threshold, max_edges, seed)

    # ----- Draw edges as thin semi-transparent lines -----
    # The weight is carried along but not used for styling: every edge gets
    # the same color, opacity and thickness.
    for (i, j, _w) in edges:
        mlab.plot3d(
            [x1[i], x2[j]],
            [y1[i], y2[j]],
            [z1_arr[i], z2_arr[j]],
            tube_radius=edge_thickness,
            tube_sides=32,
            line_width=1.0,
            color=edge_color,
            opacity=edge_opacity,
        )

    # ----- Adjust view -----
    mlab.orientation_axes()
    mlab.view(azimuth=60, elevation=75, distance='auto')
    # NOTE(review): this creates an axes object and immediately hides its
    # parent; confirm it hides only the axes decoration and not the data
    # object the axes were attached to.
    mlab.axes().parent.visible = False

    if save:
        fig = mlab.gcf()
        fig.scene.anti_aliasing_frames = 16  # or 4 if it gets too slow
        mlab.savefig(savename, size=(3000, 2400))
    else:
        mlab.show()

objc[1781]: Class RunLoopModeTracker is implemented in both /opt/anaconda3/envs/mayavi_env/lib/libQt5Core.5.15.15.dylib (0x10ed3d370) and /opt/anaconda3/envs/mayavi_env/lib/libQt6Core.6.9.3.dylib (0x3174c55d0). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.
objc[1781]: Class QT_ROOT_LEVEL_POOL__THESE_OBJECTS_WILL_BE_RELEASED_WHEN_QAPP_GOES_OUT_OF_SCOPE is implemented in both /opt/anaconda3/envs/mayavi_env/lib/libQt5Core.5.15.15.dylib (0x10ed3d2f8) and /opt/anaconda3/envs/mayavi_env/lib/libQt6Core.6.9.3.dylib (0x3174c5698). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.
objc[1781]: Class KeyValueObserver is implemented in both /opt/anaconda3/envs/mayavi_env/lib/libQt5Core.5.15.15.dylib (0x10ed3d320) and /opt/anaconda3/envs/mayavi_env/lib/libQt6Core.6.9.3.dylib (0x3174c56c0). This may cause spurious casting failures and mysterious crashes. One of the duplicates must be removed or renamed.
[2]:
# Root directory of the per-slice AnnData inputs; a later cell reads
# "<id>/<id>_commot_distance.h5ad" below it. The trailing "/" matters because
# paths are built by plain string concatenation.
adata_datadir = "./data/"

# Directory of the saved matching matrices ("full_matching_*.npy") loaded in a
# later cell; also joined by string concatenation, hence the trailing "/".
raftup_datadir = "./results/align_data/"

[3]:
def hex_to_rgb01(hex_str):
    """Convert '#RRGGBB' to (r,g,b) floats in [0,1].

    The leading '#' is optional and surrounding whitespace is ignored.
    The three-digit shorthand '#RGB' is expanded to '#RRGGBB', and an
    eight-digit '#RRGGBBAA' is accepted with the alpha component dropped.

    Raises
    ------
    ValueError
        If the string is not a 3-, 6- or 8-digit hexadecimal color.
    """
    digits = hex_str.strip().lstrip('#')
    if len(digits) == 3:
        # Shorthand: each digit is doubled, e.g. 'f80' -> 'ff8800'.
        digits = "".join(ch * 2 for ch in digits)
    # Validate explicitly: slicing alone would silently turn a short string
    # such as '12345' into a wrong color (last channel parsed from one digit).
    well_formed = len(digits) in (6, 8) and all(
        ch in "0123456789abcdefABCDEF" for ch in digits
    )
    if not well_formed:
        raise ValueError(f"Expected a hex color like '#RRGGBB', got {hex_str!r}")
    return tuple(int(digits[i:i+2], 16) / 255.0 for i in (0, 2, 4))

# Hex color per integer label; the position in the tuple is the label (0-6).
layer_color_hex = dict(enumerate((
    "#FBB4AE",
    "#B3CDE3",
    "#CCEBC5",
    "#DECBE4",
    "#FED9A6",
    "#FFFFCC",
    "#E5D8BD",
)))

# The same palette as (r, g, b) float triples in [0, 1]; this is the mapping
# later passed as `label_colors` to visualize_ot_mapping_3d.
layer_color = {
    label: hex_to_rgb01(hex_code)
    for label, hex_code in layer_color_hex.items()
}
[6]:
# Make sure the directory the figure is written to exists.
os.makedirs("./results/figures", exist_ok=True)

# Identifiers of the two slices being compared (slice 1 -> slice 2).
d_1 = "151508"
d_2 = "151509"
print(d_1, d_2)

# One AnnData file per slice, stored as <adata_datadir><id>/<id>_commot_distance.h5ad.
adata_1 = sc.read_h5ad(adata_datadir + "%s/%s_commot_distance.h5ad" % (d_1, d_1))
adata_2 = sc.read_h5ad(adata_datadir + "%s/%s_commot_distance.h5ad" % (d_2, d_2))

# Spatial coordinates and per-observation cluster labels for each slice.
# NOTE(review): dtype=int assumes every "original_clusters" entry is
# integer-like (no missing values) — confirm against the data.
X1 = adata_1.obsm["spatial"]
X2 = adata_2.obsm["spatial"]
l1 = np.array(adata_1.obs["original_clusters"], dtype=int)
l2 = np.array(adata_2.obs["original_clusters"], dtype=int)

# Matching matrix between the two slices. It is passed as P below, so its
# shape has to be (rows of X1, rows of X2).
# NOTE(review): "206_0.4" in the file name presumably encodes parameters of
# the run that produced it — confirm and consider naming them as constants.
P_raftup = np.load(raftup_datadir + f"full_matching_{d_1}_{d_2}_206_0.4.npy")

save_path = f"./results/figures/{d_1}_{d_2}_full_global.png"

# Render slice 1 at z=0 and slice 2 at z=3500, keep only the strongest match
# per source point (top_k=1) above 1e-5, cap the drawing at 1000 edges, and
# write the result to save_path instead of opening an interactive window.
visualize_ot_mapping_3d(
    X1, X2, P_raftup,
    l1, l2,
    z1=0.0,
    z2=3500.0,
    top_k=1,
    prob_threshold=1e-5,
    max_edges=1000,
    edge_opacity=0.5,
    edge_color=(0.3, 0.3, 0.3),
    label_colors=layer_color,
    point_size=75.0,
    save=True,
    savename=save_path
)

print(f"Saved figure to {save_path}")
151508 151509
Saved figure to ./results/figures/151508_151509_full_global.png