{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a8754069",
   "metadata": {},
   "source": [
    "# Plot with Mayavi\n",
    "\n",
    "3D view of an optimal-transport (OT) matching between two spatial slices: each slice is drawn as a plane of points coloured by cluster label, and each source point's highest-weight OT couplings are drawn as edges to the other slice."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0920d72d",
   "metadata": {},
   "source": [
    "We recommend using a separate plotting environment for Mayavi; the required dependencies are listed in `README.md`.\n",
    "\n",
    "On macOS, importing Mayavi may print warnings such as `objc[...]: Class ... is implemented in both .../libQt5Core...dylib and .../libQt6Core...dylib`. They mean that both Qt5 and Qt6 libraries are installed in the environment and do not come from this notebook's code; if the Mayavi window crashes, remove one of the two Qt versions from the environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c2e9d14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "from mayavi import mlab\n",
    "import scanpy as sc"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a41f0b3",
   "metadata": {},
   "source": [
    "## Plotting helpers\n",
    "\n",
    "`visualize_ot_mapping_3d` is the entry point; the underscore-prefixed helpers flatten coordinates onto a plane, resolve label colours, select the OT edges to draw and draw the points of one slice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb07d726-8e4d-463f-84f6-7708c3baadfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _as_plane_coords(X, z_plane):\n",
    "    \"\"\"Place every point of `X` on the plane z = `z_plane`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : np.ndarray\n",
    "        Point coordinates with 2 or 3 columns. A third column is ignored,\n",
    "        i.e. the slice is flattened onto the requested plane.\n",
    "    z_plane : float\n",
    "        z coordinate given to every point.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    x, y, z : np.ndarray\n",
    "        Per-point coordinates; `z` is a float array filled with `z_plane`.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `X` is not 2-D with 2 or 3 columns.\n",
    "    \"\"\"\n",
    "    if X.ndim != 2 or X.shape[1] not in (2, 3):\n",
    "        raise ValueError(\"X must have 2 or 3 columns.\")\n",
    "    x, y = X[:, 0], X[:, 1]\n",
    "    return x, y, np.full_like(x, z_plane, dtype=float)\n",
    "\n",
    "\n",
    "def _resolve_label_colors(labels, label_colors=None, seed=42):\n",
    "    \"\"\"Return a new dict mapping every label in `labels` to an (r, g, b) tuple.\n",
    "\n",
    "    Colours already in `label_colors` are reused. Every other label gets a\n",
    "    random colour with components in [0.3, 1.0) (so nothing is drawn too\n",
    "    dark), drawn in the order of `labels` from a generator seeded with `seed`.\n",
    "    The caller's `label_colors` dict is never modified.\n",
    "    \"\"\"\n",
    "    colors = {} if label_colors is None else dict(label_colors)\n",
    "    rng = np.random.default_rng(seed)\n",
    "    for lab in labels:\n",
    "        if lab not in colors:\n",
    "            colors[lab] = tuple(rng.random(3) * 0.7 + 0.3)\n",
    "    return colors\n",
    "\n",
    "\n",
    "def _select_ot_edges(P, top_k=1, prob_threshold=0.0, max_edges=800, seed=0):\n",
    "    \"\"\"Choose the OT couplings to draw, as a list of (i, j, P[i, j]) tuples.\n",
    "\n",
    "    For every source row i keep at most `top_k` targets (highest weight\n",
    "    first) and drop weights that are <= 0 or < `prob_threshold`. If more than\n",
    "    `max_edges` edges remain, keep a random subset of `max_edges` of them,\n",
    "    reproducible through `seed`. `top_k <= 0` or `max_edges <= 0` selects\n",
    "    no edges.\n",
    "    \"\"\"\n",
    "    if top_k <= 0 or max_edges <= 0:\n",
    "        return []\n",
    "    n1, n2 = P.shape\n",
    "    edges = []\n",
    "    for i in range(n1):\n",
    "        row = P[i]\n",
    "        if top_k >= n2:\n",
    "            top_idx = np.argsort(row)[::-1]  # whole row, descending\n",
    "        else:\n",
    "            # argpartition isolates the top_k entries without a full sort\n",
    "            top_idx = np.argpartition(row, -top_k)[-top_k:]\n",
    "            top_idx = top_idx[np.argsort(row[top_idx])[::-1]]\n",
    "        for j in top_idx:\n",
    "            w = row[j]\n",
    "            if w >= prob_threshold and w > 0:\n",
    "                edges.append((i, j, w))\n",
    "\n",
    "    # If still too many, subsample globally\n",
    "    if len(edges) > max_edges:\n",
    "        rng = np.random.default_rng(seed)\n",
    "        keep = rng.choice(len(edges), size=max_edges, replace=False)\n",
    "        edges = [edges[k] for k in keep]\n",
    "    return edges\n",
    "\n",
    "\n",
    "def _plot_slice_points(x, y, z, labels, label_colors, point_size):\n",
    "    \"\"\"Draw one sphere-glyph set per label of a slice, coloured via `label_colors`.\"\"\"\n",
    "    for lab in np.unique(labels):\n",
    "        mask = labels == lab\n",
    "        if not np.any(mask):  # a label can match nothing, e.g. NaN != NaN\n",
    "            continue\n",
    "        mlab.points3d(\n",
    "            x[mask], y[mask], z[mask],\n",
    "            scale_factor=point_size,\n",
    "            color=label_colors[lab],\n",
    "            mode=\"sphere\",\n",
    "            resolution=16,\n",
    "            opacity=1.0,\n",
    "        )\n",
    "\n",
    "\n",
    "def visualize_ot_mapping_3d(\n",
    "    X1, X2, P,\n",
    "    labels1, labels2,\n",
    "    label_colors=None,\n",
    "    z1=0.0, z2=1.0,\n",
    "    top_k=1,\n",
    "    prob_threshold=0.0,\n",
    "    max_edges=800,\n",
    "    edge_opacity=0.15,\n",
    "    edge_color=(0.2, 0.2, 0.2),\n",
    "    edge_thickness=None,\n",
    "    point_size=4.0,\n",
    "    bg_color=(1, 1, 1),\n",
    "    figure_size=(1000, 800),\n",
    "    seed=0,\n",
    "    save=False,\n",
    "    savename=None,\n",
    "    save_size=(3000, 2400),\n",
    "):\n",
    "    \"\"\"\n",
    "    Visualize OT mapping between two slices using Mayavi in 3D.\n",
    "\n",
    "    Slice 1 is drawn on the plane z = `z1` and slice 2 on z = `z2`, each point\n",
    "    as a sphere coloured by its label. Selected OT couplings (see `top_k`,\n",
    "    `prob_threshold`, `max_edges`) are drawn as straight edges between them.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X1, X2 : array-like\n",
    "        Coordinates of points in slice 1 and slice 2.\n",
    "        Shape: (n1, 2) or (n1, 3), (n2, 2) or (n2, 3). A third column is\n",
    "        ignored: each slice is flattened onto its z plane.\n",
    "    P : array-like\n",
    "        OT matrix of shape (n1, n2).\n",
    "    labels1, labels2 : array-like\n",
    "        Cell-type labels (e.g. strings or ints) for X1 and X2, of length n1\n",
    "        and n2.\n",
    "    label_colors : dict or None\n",
    "        Optional mapping {label: (r,g,b)} with values in [0,1]. Labels it\n",
    "        does not cover (all labels, if None) get automatically assigned\n",
    "        colors. The dict itself is not modified.\n",
    "    z1, z2 : float\n",
    "        z coordinates where the two slices will be placed.\n",
    "    top_k : int\n",
    "        For each source point in X1, keep at most top_k target connections\n",
    "        (highest weight first). top_k <= 0 draws no edges.\n",
    "    prob_threshold : float\n",
    "        Ignore edges with P[i,j] < prob_threshold. Edges with P[i,j] <= 0\n",
    "        are always ignored.\n",
    "    max_edges : int\n",
    "        Maximum number of edges to draw in total. If exceeded, a random\n",
    "        subset of max_edges edges is drawn.\n",
    "    edge_opacity : float\n",
    "        Opacity of edges (0–1).\n",
    "    edge_color : tuple\n",
    "        RGB color of edges, values in [0,1].\n",
    "    edge_thickness : float or None\n",
    "        Tube radius passed to `mlab.plot3d(tube_radius=...)`; None draws\n",
    "        plain lines.\n",
    "    point_size : float\n",
    "        Glyph scale factor of the points (`mlab.points3d(scale_factor=...)`).\n",
    "    bg_color : tuple\n",
    "        Background color of the figure.\n",
    "    figure_size : tuple\n",
    "        Size of the Mayavi figure (width, height).\n",
    "    seed : int\n",
    "        Random seed for edge subsampling.\n",
    "    save : bool\n",
    "        If True, write the figure to `savename` with `mlab.savefig` instead\n",
    "        of calling `mlab.show()`.\n",
    "    savename : str or None\n",
    "        Output image path; required when save=True.\n",
    "    save_size : tuple\n",
    "        Size (width, height) passed to `mlab.savefig` when saving.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If P is not 2-D, if X1/X2/labels1/labels2 do not match P's shape,\n",
    "        if X1 or X2 does not have 2 or 3 columns, or if save=True without\n",
    "        a savename.\n",
    "    \"\"\"\n",
    "    X1 = np.asarray(X1)\n",
    "    X2 = np.asarray(X2)\n",
    "    P = np.asarray(P)\n",
    "    labels1 = np.asarray(labels1)\n",
    "    labels2 = np.asarray(labels2)\n",
    "\n",
    "    # Validate everything before opening a Mayavi window.\n",
    "    if P.ndim != 2:\n",
    "        raise ValueError(f\"P must be a 2-D matrix, got shape {P.shape}.\")\n",
    "    n1, n2 = P.shape\n",
    "    if X1.shape[0] != n1 or X2.shape[0] != n2:\n",
    "        raise ValueError(\n",
    "            f\"P has shape {P.shape} but X1 has {X1.shape[0]} rows \"\n",
    "            f\"and X2 has {X2.shape[0]} rows.\"\n",
    "        )\n",
    "    if len(labels1) != n1 or len(labels2) != n2:\n",
    "        raise ValueError(\n",
    "            f\"P has shape {P.shape} but labels1 has {len(labels1)} entries \"\n",
    "            f\"and labels2 has {len(labels2)} entries.\"\n",
    "        )\n",
    "    if save and not savename:\n",
    "        raise ValueError(\"savename must be given when save=True.\")\n",
    "\n",
    "    x1, y1, z1_arr = _as_plane_coords(X1, z1)\n",
    "    x2, y2, z2_arr = _as_plane_coords(X2, z2)\n",
    "\n",
    "    # One colour per label across both slices, so shared labels match.\n",
    "    all_labels = np.unique(np.concatenate([labels1, labels2]))\n",
    "    colors = _resolve_label_colors(all_labels, label_colors)\n",
    "\n",
    "    fig = mlab.figure(bgcolor=bg_color, size=figure_size)\n",
    "    _plot_slice_points(x1, y1, z1_arr, labels1, colors, point_size)\n",
    "    _plot_slice_points(x2, y2, z2_arr, labels2, colors, point_size)\n",
    "\n",
    "    # Draw the selected couplings as straight segments between the two planes.\n",
    "    # (Opacity or tube radius could be modulated by the weight w here.)\n",
    "    for i, j, _w in _select_ot_edges(P, top_k, prob_threshold, max_edges, seed):\n",
    "        mlab.plot3d(\n",
    "            [x1[i], x2[j]],\n",
    "            [y1[i], y2[j]],\n",
    "            [z1_arr[i], z2_arr[j]],\n",
    "            tube_radius=edge_thickness,\n",
    "            tube_sides=32,\n",
    "            line_width=1.0,\n",
    "            color=edge_color,\n",
    "            opacity=edge_opacity,\n",
    "        )\n",
    "\n",
    "    mlab.orientation_axes()\n",
    "    mlab.view(azimuth=60, elevation=75, distance=\"auto\")\n",
    "    # NOTE(review): this creates an Axes module and hides its *parent* module\n",
    "    # manager; confirm nothing else attached to that manager is hidden too.\n",
    "    mlab.axes().parent.visible = False\n",
    "\n",
    "    if save:\n",
    "        fig.scene.anti_aliasing_frames = 16  # or 4 if it gets too slow\n",
    "        mlab.savefig(savename, size=save_size)\n",
    "    else:\n",
    "        mlab.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c93d6e28",
   "metadata": {},
   "source": [
    "## Paths and colour palette"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eed8ef58-d0a1-4045-83bd-d6d9ae262cc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths are relative to the notebook's working directory.\n",
    "ADATA_DIR = \"./data/\"  # per-slice AnnData files: <id>/<id>_commot_distance.h5ad\n",
    "RAFTUP_DIR = \"./results/align_data/\"  # saved OT matchings (.npy)\n",
    "FIGURE_DIR = \"./results/figures/\"  # rendered figures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2bcecda-bff2-4ec1-be40-9aeebfc01b4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hex_to_rgb01(hex_str):\n",
    "    \"\"\"Convert '#RRGGBB' (or shorthand '#RGB') to an (r, g, b) tuple of floats in [0, 1].\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `hex_str` is not a 3- or 6-digit hexadecimal colour.\n",
    "    \"\"\"\n",
    "    digits = hex_str.lstrip(\"#\")\n",
    "    if len(digits) == 3:\n",
    "        digits = \"\".join(ch * 2 for ch in digits)  # 'abc' -> 'aabbcc'\n",
    "    if len(digits) != 6:\n",
    "        raise ValueError(f\"Expected '#RRGGBB' or '#RGB', got {hex_str!r}.\")\n",
    "    return tuple(int(digits[i:i + 2], 16) / 255.0 for i in (0, 2, 4))\n",
    "\n",
    "\n",
    "# Colours for the integer labels 0-6 read from obs[\"original_clusters\"] below.\n",
    "LAYER_COLOR_HEX = {\n",
    "    0: \"#FBB4AE\",\n",
    "    1: \"#B3CDE3\",\n",
    "    2: \"#CCEBC5\",\n",
    "    3: \"#DECBE4\",\n",
    "    4: \"#FED9A6\",\n",
    "    5: \"#FFFFCC\",\n",
    "    6: \"#E5D8BD\",\n",
    "}\n",
    "LAYER_COLORS = {k: hex_to_rgb01(v) for k, v in LAYER_COLOR_HEX.items()}\n",
    "assert hex_to_rgb01(\"#FF0000\") == (1.0, 0.0, 0.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0b7a512",
   "metadata": {},
   "source": [
    "## Load a slice pair and its OT matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f6d84c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "SLICE_1 = \"151508\"\n",
    "SLICE_2 = \"151509\"\n",
    "# Trailing tag of the saved matching file name.\n",
    "# NOTE(review): document what 206 and 0.4 stand for (run parameters?).\n",
    "MATCHING_TAG = \"206_0.4\"\n",
    "\n",
    "adata_1 = sc.read_h5ad(os.path.join(ADATA_DIR, SLICE_1, f\"{SLICE_1}_commot_distance.h5ad\"))\n",
    "adata_2 = sc.read_h5ad(os.path.join(ADATA_DIR, SLICE_2, f\"{SLICE_2}_commot_distance.h5ad\"))\n",
    "\n",
    "X1 = adata_1.obsm[\"spatial\"]\n",
    "X2 = adata_2.obsm[\"spatial\"]\n",
    "labels_1 = np.array(adata_1.obs[\"original_clusters\"], dtype=int)\n",
    "labels_2 = np.array(adata_2.obs[\"original_clusters\"], dtype=int)\n",
    "\n",
    "P_raftup = np.load(\n",
    "    os.path.join(RAFTUP_DIR, f\"full_matching_{SLICE_1}_{SLICE_2}_{MATCHING_TAG}.npy\")\n",
    ")\n",
    "\n",
    "# The matching must be (observations in slice 1) x (observations in slice 2).\n",
    "assert P_raftup.shape == (adata_1.n_obs, adata_2.n_obs), P_raftup.shape\n",
    "P_raftup.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b3e5f70",
   "metadata": {},
   "source": [
    "## Plot and save\n",
    "\n",
    "Each point in slice 1 keeps its single highest-weight coupling (`top_k=1`) if it is at least `prob_threshold`; at most `max_edges` of those edges are drawn (a random subset, seeded). The figure is written to `FIGURE_DIR` instead of opening an interactive window."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecadb8aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(FIGURE_DIR, exist_ok=True)\n",
    "save_path = os.path.join(FIGURE_DIR, f\"{SLICE_1}_{SLICE_2}_full_global.png\")\n",
    "\n",
    "visualize_ot_mapping_3d(\n",
    "    X1, X2, P_raftup,\n",
    "    labels_1, labels_2,\n",
    "    label_colors=LAYER_COLORS,\n",
    "    z1=0.0,\n",
    "    z2=3500.0,  # gap between the slice planes, in the units of obsm[\"spatial\"]\n",
    "    top_k=1,\n",
    "    prob_threshold=1e-5,\n",
    "    max_edges=1000,\n",
    "    edge_opacity=0.5,\n",
    "    edge_color=(0.3, 0.3, 0.3),\n",
    "    point_size=75.0,\n",
    "    save=True,\n",
    "    savename=save_path,\n",
    ")\n",
    "print(f\"Saved figure to {save_path}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "raftup",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}