DLPFC Gene Cost by SOMDE
In this notebook, embeddings are generated from the DLPFC data and used to compute gene cost matrices between pairs of sections. SOMDE is used to select the spatially variable genes. If you don’t have the SOMDE results yet, please run 0_somde_robust_scanpy.ipynb to generate them.
[1]:
import os
import random
import numpy as np
import torch
import scanpy as sc
import pandas as pd
from raftup import _gene_cost_somde
# ---------------------------------------------------------
# User options (edit these before running)
# ---------------------------------------------------------
# Execution mode, one of: "reproducible" (CPU, deterministic) or "fast".
mode = "fast"
seed = 0

# Input / output locations
ROOT_DIR = "./data"
SOMDE_DIR = "./results/somde"
OUT_DIR = "./results/gene_cost_matrix"

# How many rows (genes) to keep from the top of each section's SOMDE table
N_TOP_SV_GENES = 3_000

# (section A, section B) pairs for which a gene cost matrix is computed
SECTION_PAIRS = [("151508", "151509")]
# =========================================================
# Setup
# =========================================================
def setup_environment(mode="reproducible", seed=0):
    """
    Seed the random number generators and choose the torch device.

    Parameters
    ----------
    mode : {"reproducible", "fast"}
        "reproducible":
            - force CPU
            - enable deterministic algorithms
            - limit torch to a single intra-op thread
        "fast":
            - prefer CUDA, then MPS, then CPU
            - allow nondeterministic algorithms (and cuDNN autotuning on CUDA)
    seed : int
        Seed applied to Python's ``random``, NumPy and torch.

    Returns
    -------
    torch.device
        The device selected for this run.

    Raises
    ------
    ValueError
        If ``mode`` is not "reproducible" or "fast".
    """
    if mode not in {"reproducible", "fast"}:
        raise ValueError("mode must be 'reproducible' or 'fast'")

    # Seeding is identical for both modes, so do it once up front.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if mode == "reproducible":
        # NOTE: OpenMP/MKL usually read these variables when their runtime is
        # first initialised. numpy and torch are already imported at this
        # point, so the variables may only affect processes started later.
        # torch.set_num_threads(1) is what actually limits torch's intra-op
        # parallelism in this process.
        os.environ["OMP_NUM_THREADS"] = "1"
        os.environ["MKL_NUM_THREADS"] = "1"
        torch.set_num_threads(1)
        torch.use_deterministic_algorithms(True)
        device = torch.device("cpu")
    else:
        # torch builds without the MPS backend have no torch.backends.mps;
        # treat that the same as "MPS not available" instead of crashing.
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            device = torch.device("cuda")
        elif mps_backend is not None and mps_backend.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

        # Trade determinism for speed.
        torch.use_deterministic_algorithms(False)
        if device.type == "cuda":
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False

    print(f"Mode: {mode}")
    print(f"Using device: {device}")
    print("OMP_NUM_THREADS:", os.environ.get("OMP_NUM_THREADS"))
    print("MKL_NUM_THREADS:", os.environ.get("MKL_NUM_THREADS"))
    print("torch num threads:", torch.get_num_threads())
    print("torch interop threads:", torch.get_num_interop_threads())
    return device
def load_DLPFC(root_dir: str, section_id: str):
    """
    Load one DLPFC Visium slice and attach its ground-truth layer labels.

    Expected layout under ``<root_dir>/<section_id>/``:
      - Visium output readable by ``sc.read_visium``, with counts in
        ``<section_id>_filtered_feature_bc_matrix.h5``;
      - ``gt/tissue_positions_list_GTs.txt``: a header-less CSV whose first
        column is the index aligned against ``adata.obs`` (presumably the spot
        barcode) and whose column 6 holds the layer label.

    Spots without a label are removed. Labels are stored as strings (e.g. "3")
    in ``adata.obs["original_clusters"]``.
    """
    section_dir = os.path.join(root_dir, section_id)
    adata = sc.read_visium(
        path=section_dir,
        count_file=f"{section_id}_filtered_feature_bc_matrix.h5",
    )
    adata.var_names_make_unique()

    gt_df = pd.read_csv(
        os.path.join(section_dir, "gt", "tissue_positions_list_GTs.txt"),
        sep=",",
        header=None,
        index_col=0,
    )
    # Assignment aligns on the obs index: spots absent from the GT file get NaN.
    adata.obs["original_clusters"] = gt_df.loc[:, 6]

    # Filter on the label column only; obs.dropna() would also discard spots
    # that happen to have NaN in unrelated obs columns.
    has_label = adata.obs["original_clusters"].notna().to_numpy()
    adata = adata[has_label].copy()

    # Labels may have been read as float (because of NaN); store "3", not "3.0".
    adata.obs["original_clusters"] = (
        adata.obs["original_clusters"].astype(int).astype(str)
    )
    print(f"[load_DLPFC] Loaded section {section_id}")
    return adata
# =========================================================
# Main
# =========================================================
def _somde_top_genes(somde_dir, section_id, adata, n_top):
    """
    Return the first ``n_top`` genes (column "g") of a section's SOMDE table.

    Raises KeyError with a descriptive message if any of those genes is not
    in ``adata.var_names``. This can happen when the SOMDE table was computed
    on differently processed data, e.g. before ``var_names_make_unique``.
    """
    df_somde = pd.read_csv(os.path.join(somde_dir, f"somde_{section_id}.csv"))
    # NOTE(review): taking the first rows assumes the SOMDE CSV is already
    # ranked (most spatially variable first) by the notebook that wrote it.
    genes = df_somde["g"].values[:n_top]
    missing = [g for g in genes if g not in adata.var_names]
    if missing:
        raise KeyError(
            f"{len(missing)} SOMDE genes for section {section_id} are not in "
            f"the expression matrix (e.g. {missing[:5]}); was the SOMDE table "
            "computed on the same data?"
        )
    return genes


device = setup_environment(mode=mode, seed=seed)
os.makedirs(OUT_DIR, exist_ok=True)

for sid_A, sid_B in SECTION_PAIRS:
    print(f"\n=== Computing gene cost: {sid_A} vs {sid_B} ===")
    sliceA = load_DLPFC(ROOT_DIR, sid_A)
    sliceB = load_DLPFC(ROOT_DIR, sid_B)

    # Normalize and log-transform on the full gene set, before gene selection.
    for sl in (sliceA, sliceB):
        sc.pp.normalize_total(sl)
        sc.pp.log1p(sl)

    # Each slice keeps its own SOMDE gene list; the two lists may differ.
    sv_genes_A = _somde_top_genes(SOMDE_DIR, sid_A, sliceA, N_TOP_SV_GENES)
    sv_genes_B = _somde_top_genes(SOMDE_DIR, sid_B, sliceB, N_TOP_SV_GENES)
    sliceA = sliceA[:, sv_genes_A].copy()
    sliceB = sliceB[:, sv_genes_B].copy()

    for sl in (sliceA, sliceB):
        sc.pp.scale(sl)

    _gene_cost_somde.compute_gene_cost(
        sliceA=sliceA,
        sliceB=sliceB,
        section_id_A=sid_A,
        section_id_B=sid_B,
        n_h=100,
        n_epoch=3500,
        lr=2e-4,
        print_step=500,
        seed=seed,
        device=device,
        output_dir=OUT_DIR,
    )

print("\nAll gene cost matrices computed successfully.")
Mode: fast
Using device: mps
OMP_NUM_THREADS: None
MKL_NUM_THREADS: None
torch num threads: 12
torch interop threads: 12
=== Computing gene cost: 151508 vs 151509 ===
[load_DLPFC] Loaded section 151508
[load_DLPFC] Loaded section 151509
[0000] DGI loss = 2.8241
[0500] DGI loss = 0.0010
[1000] DGI loss = 0.0003
[1500] DGI loss = 0.0002
[2000] DGI loss = 0.0001
[2500] DGI loss = 0.0000
[3000] DGI loss = 0.0000
[3499] DGI loss = 0.0001
All gene cost matrices computed successfully.