diff --git a/.github/workflows/update-notebooks.yml b/.github/workflows/update-notebooks.yml new file mode 100644 index 0000000..b0f4fda --- /dev/null +++ b/.github/workflows/update-notebooks.yml @@ -0,0 +1,35 @@ +name: Update Colab notebooks + +on: + push: + branches: [main] + paths: + - "examples/[0-9][0-9]_*.py" + - "docs/make_notebooks.py" + +jobs: + update-notebooks: + runs-on: ubuntu-latest + permissions: + contents: write + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install jupytext + run: pip install jupytext + + - name: Regenerate notebooks + run: python docs/make_notebooks.py + + - name: Commit updated notebooks + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git add docs/auto_examples/*.ipynb + git diff --staged --quiet || git commit -m "auto: regenerate Colab notebooks from .py examples [skip ci]" + git push diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..769760f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +repos: + - repo: local + hooks: + - id: update-colab-notebooks + name: Regenerate Colab notebooks + language: python + additional_dependencies: [jupytext] + entry: python docs/make_notebooks.py --stage --examples + files: ^examples/\d{2}_.*\.py$ + pass_filenames: true diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..c9088cf --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,23 @@ +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.12" + # RTD does not execute sphinx-gallery examples (GPU-dependent, slow). + # Set BRAINDEC_BUILD_GALLERY=1 locally to build with executed outputs. + jobs: + post_install: + - python -c "from braindec._version import __version__; print(__version__)" + +sphinx: + configuration: docs/conf.py + fail_on_warning: false + +python: + install: + - method: pip + path: . + extra_requirements: + - doc + - plotting diff --git a/README.md b/README.md index c371908..5652c4d 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,24 @@ The trained baseline models use in the paper can be downloaded from the OSF repo Alternatively, you can use the pre-trained model provided in the `./results/pubmed` directory in https://osf.io/dsj56/. +### Download published assets + +The package now includes an OSF downloader for the assets documented in this README. It can download individual files, predefined bundles, or whole published folders while recreating the OSF directory layout under a destination root. + +```bash +# List built-in assets and bundles +python -m braindec.fetcher --list + +# Download the example prediction bundle into the current repository +python -m braindec.fetcher --bundle example_prediction --destination_root . + +# Download the published pretrained results and baseline folders +python -m braindec.fetcher --bundle paper_results --destination_root . + +# Download a specific published folder from the OSF project +python -m braindec.fetcher --folder data/cognitive_atlas --destination_root . +``` + ### Predictions To perform predictions using the trained model, you can use the [predict.py](./braindec/predict.py) script. diff --git a/braindec/__init__.py b/braindec/__init__.py index 96cfab8..14d079b 100644 --- a/braindec/__init__.py +++ b/braindec/__init__.py @@ -1,13 +1,22 @@ """Braindec: Brain image decoder.""" -from . import dataset, embedding, loss, model, plot, train, utils # predict +from importlib import import_module __all__ = [ - "model", "dataset", - "loss", "embedding", + "fetcher", + "loss", + "model", "plot", "train", "utils", ] + + +def __getattr__(name): + if name in __all__: + module = import_module(f".{name}", __name__) + globals()[name] = module + return module + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/braindec/cogatlas.py b/braindec/cogatlas.py index 245ed88..4cab1d2 100644 --- a/braindec/cogatlas.py +++ b/braindec/cogatlas.py @@ -1,9 +1,10 @@ import json +import os.path as op +from concurrent.futures import ThreadPoolExecutor import numpy as np import pandas as pd import requests -from nimare import extract COGATLAS_URLS = { "task": "https://www.cognitiveatlas.org/api/v-alpha/task", @@ -56,6 +57,43 @@ def _get_concepts_to_tasks(relationships_df, concept_to_task=None): return concepts_to_tasks_df +def _fetch_full_task_concepts(task_ids, cache_fn, concept_to_task=None, max_workers=16): + if cache_fn is not None and op.exists(cache_fn): + concepts_to_tasks_df = pd.read_csv(cache_fn) + else: + base_url = "https://www.cognitiveatlas.org/api/v-alpha/task" + + def _fetch_one(task_id): + response = requests.get(base_url, params={"id": task_id}, timeout=30) + response.raise_for_status() + task_json = response.json() + rows = [] + for concept in task_json.get("concepts", []): + concept_id = concept.get("concept_id") + if concept_id: + rows.append({"id": concept_id, "measuredBy": task_id}) + return rows + + rows = [] + with ThreadPoolExecutor(max_workers=max_workers) as executor: + for task_rows in executor.map(_fetch_one, task_ids): + rows.extend(task_rows) + concepts_to_tasks_df = pd.DataFrame(rows).drop_duplicates() + if cache_fn is not None: + concepts_to_tasks_df.to_csv(cache_fn, index=False) + + if concept_to_task is not None: + extra_df = pd.DataFrame( + { + "id": list(concept_to_task.keys()), + "measuredBy": list(concept_to_task.values()), + } + ) + concepts_to_tasks_df = pd.concat([concepts_to_tasks_df, extra_df], ignore_index=True) + concepts_to_tasks_df = concepts_to_tasks_df.drop_duplicates() + return concepts_to_tasks_df + + class CognitiveAtlas: def __init__( self, @@ -131,11 +169,12 @@ def __init__( if reduced_tasks is not None: concepts_to_tasks_df = self._get_concepts_to_tasks_red(reduced_tasks) else: - cogatlas = extract.download_cognitive_atlas(data_dir=data_dir, overwrite=False) - relationships_df = pd.read_csv(cogatlas["relationships"]) - - concepts_to_tasks_df = _get_concepts_to_tasks( - relationships_df, + cache_fn = None + if data_dir is not None: + cache_fn = op.join(data_dir, "cognitive_atlas", "full_task_concepts.csv") + concepts_to_tasks_df = _fetch_full_task_concepts( + self.task_df["id"].tolist(), + cache_fn=cache_fn, concept_to_task=concept_to_task, ) @@ -148,14 +187,14 @@ def __init__( continue sel_tasks = sel_df["measuredBy"].values - indices = np.where(np.in1d(self.task_df["id"].values, sel_tasks))[0] + indices = np.where(np.isin(self.task_df["id"].values, sel_tasks))[0] self.concept_to_task_idxs.append(indices) self.process_to_concept_idxs = [] for process in self.process_names: sel_df = self.concept_df.loc[self.concept_df["cognitive_process"] == process] - indices = np.where(np.in1d(self.concept_df["id"].values, sel_df["id"].values))[0] + indices = np.where(np.isin(self.concept_df["id"].values, sel_df["id"].values))[0] self.process_to_concept_idxs.append(indices) @@ -167,7 +206,7 @@ def __init__( continue sel_concepts = sel_df["id"].values - indices = np.where(np.in1d(self.concept_df["id"].values, sel_concepts))[0] + indices = np.where(np.isin(self.concept_df["id"].values, sel_concepts))[0] self.task_to_concept_idxs.append(indices) @@ -176,11 +215,9 @@ def __init__( sel_concepts = concepts_to_tasks_df.loc[concepts_to_tasks_df["measuredBy"] == task][ "id" ].values - if task == "trm_550b54a8b30f4": - print(sel_concepts) sel_df = self.concept_df.loc[self.concept_df["id"].isin(sel_concepts)] - indices = np.where(np.in1d(self.process_ids, sel_df["id_concept_class"].values))[0] + indices = np.where(np.isin(self.process_ids, sel_df["id_concept_class"].values))[0] self.task_to_process_idxs.append(indices) @@ -198,24 +235,24 @@ def get_task_id_from_name(self, names): def get_task_idx_from_names(self, names): if isinstance(names, str): - return np.where(np.in1d(self.task_names, names))[0][0] + return np.where(np.isin(self.task_names, names))[0][0] - return [np.where(np.in1d(self.task_names, task_name))[0][0] for task_name in names] + return [np.where(np.isin(self.task_names, task_name))[0][0] for task_name in names] def get_concept_idx_from_names(self, names): if isinstance(names, str): - return np.where(np.in1d(self.concept_names, names))[0][0] + return np.where(np.isin(self.concept_names, names))[0][0] return [ - np.where(np.in1d(self.concept_names, concept_name))[0][0] for concept_name in names + np.where(np.isin(self.concept_names, concept_name))[0][0] for concept_name in names ] def get_process_idx_from_names(self, names): if isinstance(names, str): - return np.where(np.in1d(self.process_names, names))[0][0] + return np.where(np.isin(self.process_names, names))[0][0] return [ - np.where(np.in1d(self.process_names, process_name))[0][0] for process_name in names + np.where(np.isin(self.process_names, process_name))[0][0] for process_name in names ] def get_task_names_from_idx(self, task_idx): diff --git a/braindec/embedding.py b/braindec/embedding.py index 9ae05e5..093f9cc 100644 --- a/braindec/embedding.py +++ b/braindec/embedding.py @@ -6,15 +6,13 @@ import numpy as np import torch from nilearn import datasets -from nilearn.image import concat_imgs +from nilearn.image import concat_imgs, load_img, new_img_like, resample_to_img from nilearn.maskers import NiftiMapsMasker, SurfaceMapsMasker from nimare.dataset import Dataset from nimare.meta.kernel import MKDAKernel -from peft import PeftConfig, PeftModel from tqdm import tqdm -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer -from braindec.utils import _get_device, _vol_surfimg +from braindec.utils import _get_device, _vol_surfimg, images_have_same_fov def _coordinates_to_image(dset: Dataset, kernel: str = "mkda"): @@ -51,16 +49,23 @@ def __init__( self.batch_size = batch_size if model_name == "mistralai/Mistral-7B-v0.1": + from transformers import AutoModel, AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = AutoModel.from_pretrained(model_name).to(self.device) self.max_length = 8192 if max_length is None else max_length elif model_name == "meta-llama/Llama-2-7b-chat-hf": + from transformers import AutoModelForCausalLM, AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device) self.max_length = 4096 if max_length is None else max_length elif model_name == "BrainGPT/BrainGPT-7B-v0.1": + from peft import PeftConfig, PeftModel + from transformers import AutoModelForCausalLM, AutoTokenizer + config = PeftConfig.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path) @@ -69,6 +74,9 @@ def __init__( self.max_length = 4096 if max_length is None else max_length elif model_name == "BrainGPT/BrainGPT-7B-v0.2": + from peft import PeftConfig, PeftModel + from transformers import AutoModelForCausalLM, AutoTokenizer + config = PeftConfig.from_pretrained(model_name) # The config file has path to the base model instead of the model name model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") @@ -232,15 +240,23 @@ def __init__( self.dimension = dimension self.space = space self.density = density + self._maps_cache = {} if self.atlas == "difumo": - difumo = datasets.fetch_atlas_difumo( - dimension=self.dimension, - resolution_mm=2, - legacy_format=False, - data_dir=self.nilearn_dir, - ) + difumo_kwargs = { + "dimension": self.dimension, + "resolution_mm": 2, + "data_dir": self.nilearn_dir, + } + try: + difumo = datasets.fetch_atlas_difumo( + legacy_format=False, + **difumo_kwargs, + ) + except TypeError: + difumo = datasets.fetch_atlas_difumo(**difumo_kwargs) atlas_maps = difumo.maps + self.atlas_maps = load_img(atlas_maps) else: # Implement other atlases raise ValueError(f"Atlas {atlas} not supported.") @@ -274,11 +290,75 @@ def generate_embedding(self, images) -> np.ndarray: Returns: Numpy array containing the embedding """ + if self.space == "MNI152": + return self._generate_volume_embedding(images) + if isinstance(images, list): + images = [self._sanitize_image(image) for image in images] # Concat images to improve performance images = concat_imgs(images) + else: + images = self._sanitize_image(images) + + embeddings = self.masker.fit_transform(images) + if embeddings.ndim == 1: + embeddings = embeddings[None, :] + + return embeddings + + @staticmethod + def _sanitize_image(image): + image = load_img(image) + image_data = image.get_fdata() + if np.isfinite(image_data).all(): + return image + + warnings.warn("Non-finite values detected in image data. Replacing them with zeros.") + image_data = np.nan_to_num(image_data, nan=0.0, posinf=0.0, neginf=0.0) + return new_img_like(image, image_data, copy_header=True) + + def _get_maps_data(self, reference_img): + if reference_img is None: + cache_key = ("native",) + else: + cache_key = ( + tuple(reference_img.shape[:3]), + tuple(reference_img.affine.ravel()), + ) - return self.masker.fit_transform(images) + if cache_key not in self._maps_cache: + atlas_img = self.atlas_maps + if reference_img is not None and not images_have_same_fov(atlas_img, reference_img): + atlas_img = resample_to_img(atlas_img, reference_img, interpolation="continuous") + + maps_data = atlas_img.get_fdata(dtype=np.float32) + maps_data = np.nan_to_num(maps_data, nan=0.0, posinf=0.0, neginf=0.0) + maps_gram = np.tensordot( + maps_data, + maps_data, + axes=([0, 1, 2], [0, 1, 2]), + ).astype(np.float32) + maps_gram += np.eye(maps_gram.shape[0], dtype=np.float32) * 1e-6 + self._maps_cache[cache_key] = (maps_data, maps_gram) + + return self._maps_cache[cache_key] + + def _generate_volume_embedding(self, images) -> np.ndarray: + image = self._sanitize_image(images) + if not images_have_same_fov(image, self.atlas_maps): + image = resample_to_img(image, self.atlas_maps, interpolation="continuous") + image_data = image.get_fdata(dtype=np.float32) + image_data = np.nan_to_num(image_data, nan=0.0, posinf=0.0, neginf=0.0) + if image_data.ndim == 3: + image_data = image_data[..., None] + + maps_data, maps_gram = self._get_maps_data(None) + xty = np.tensordot(maps_data, image_data, axes=([0, 1, 2], [0, 1, 2])).astype(np.float32) + embeddings = np.linalg.solve(maps_gram, xty).T + if embeddings.ndim == 1: + embeddings = embeddings[None, :] + + return embeddings def __call__(self, images) -> np.ndarray: """ diff --git a/braindec/fetcher.py b/braindec/fetcher.py index e44739c..dd1dd41 100644 --- a/braindec/fetcher.py +++ b/braindec/fetcher.py @@ -1,18 +1,23 @@ -"""Fetch data.""" +"""Download published braindec assets from OSF.""" -import hashlib +import argparse import os -import pickle -import shutil +import os.path as op +import time +from pathlib import Path -import numpy as np import pandas as pd import requests -from nilearn.datasets._utils import fetch_single_file - -from braindec.utils import get_data_dir +OSF_API_BASE = "https://api.osf.io/v2" OSF_URL = "https://osf.io/{}/download" +DEFAULT_OSF_NODE = "dsj56" +DEFAULT_PROVIDER = "osfstorage" +CHUNK_SIZE = 1024 * 1024 +DEFAULT_MAX_RETRIES = 5 +DEFAULT_RETRY_BACKOFF_SECONDS = 2 + +# Legacy term/classification files used by the old vocabulary fetcher. OSF_DICT = { "source-neuroquery_desc-gclda_features.csv": "trcxs", "source-neuroquery_desc-gclda_classification.csv": "93dvg", @@ -28,127 +33,448 @@ "source-neurosynth_desc-term_classification.csv": "sd4wy", } +# Public assets documented in the README. +OSF_ASSETS = { + "text_embeddings_braingpt_v02_body": { + "type": "file", + "file_id": "v748f", + "description": "Body text embeddings used in the paper.", + }, + "image_embeddings_difumo512_mkda": { + "type": "file", + "file_id": "nu2s7", + "description": "Normalized MKDA/DiFuMo image embeddings used in the paper.", + }, + "example_model_braingpt_v02_body": { + "type": "file", + "file_id": "u3cxh", + "description": "Example pretrained CLIP model.", + }, + "example_vocabulary_cogatlasred_task": { + "type": "file", + "file_id": "8m2fz", + "description": "Reduced Cognitive Atlas task vocabulary.", + }, + "example_vocabulary_embeddings_cogatlasred_task": { + "type": "file", + "file_id": "nza7b", + "description": "Example vocabulary embeddings for reduced CogAt tasks.", + }, + "example_vocabulary_prior_cogatlasred_task": { + "type": "file", + "file_id": "v82za", + "description": "Example vocabulary prior for reduced CogAt tasks.", + }, + "brain_mask_mni152_2mm": { + "type": "file", + "file_id": "jzvry", + "description": "Brain mask used in prediction examples.", + }, + "cognitive_atlas": { + "type": "folder", + "remote_path": "data/cognitive_atlas", + "description": "Cognitive Atlas snapshots and reduced task mapping.", + }, + "results_pubmed": { + "type": "folder", + "remote_path": "results/pubmed", + "description": "Published pretrained CLIP outputs from the paper.", + }, + "results_baseline": { + "type": "folder", + "remote_path": "results/baseline", + "description": "Published baseline decoder models from the paper.", + }, +} + +OSF_BUNDLES = { + "example_prediction": [ + "example_model_braingpt_v02_body", + "example_vocabulary_cogatlasred_task", + "example_vocabulary_embeddings_cogatlasred_task", + "example_vocabulary_prior_cogatlasred_task", + "brain_mask_mni152_2mm", + "cognitive_atlas", + ], + "training_embeddings": [ + "text_embeddings_braingpt_v02_body", + "image_embeddings_difumo512_mkda", + ], + "paper_results": [ + "results_pubmed", + "results_baseline", + ], + "all_readme_assets": [ + "text_embeddings_braingpt_v02_body", + "image_embeddings_difumo512_mkda", + "example_model_braingpt_v02_body", + "example_vocabulary_cogatlasred_task", + "example_vocabulary_embeddings_cogatlasred_task", + "example_vocabulary_prior_cogatlasred_task", + "brain_mask_mni152_2mm", + "cognitive_atlas", + "results_pubmed", + "results_baseline", + ], +} + + +def get_data_dir(data_dir=None): + """Return the default braindec data directory without importing heavy modules.""" + if data_dir is None: + data_dir = os.environ.get("BRAINDEC_DATA", os.path.join("~", "braindec-data")) + data_dir = os.path.expanduser(data_dir) + os.makedirs(data_dir, exist_ok=True) + return data_dir + + +def _should_retry_request(error): + response = getattr(error, "response", None) + if response is None: + return True + if response.status_code == 403 and "osf" in response.url: + return True + return response.status_code >= 500 + + +def _request_json(url, params=None, timeout=60, max_retries=DEFAULT_MAX_RETRIES): + last_error = None + for attempt in range(max_retries): + try: + response = requests.get(url, params=params, timeout=timeout) + response.raise_for_status() + return response.json() + except requests.RequestException as error: + last_error = error + if attempt == max_retries - 1 or not _should_retry_request(error): + raise + time.sleep(DEFAULT_RETRY_BACKOFF_SECONDS**attempt) + + raise last_error + + +def _download_to_file( + url, + destination, + overwrite=False, + timeout=60, + chunk_size=CHUNK_SIZE, + max_retries=DEFAULT_MAX_RETRIES, +): + destination = Path(destination) + destination.parent.mkdir(parents=True, exist_ok=True) + if destination.exists() and not overwrite: + return destination + + tmp_destination = destination.with_suffix(destination.suffix + ".part") + last_error = None + for attempt in range(max_retries): + try: + with requests.get(url, stream=True, timeout=timeout) as response: + response.raise_for_status() + with tmp_destination.open("wb") as file_obj: + for chunk in response.iter_content(chunk_size=chunk_size): + if chunk: + file_obj.write(chunk) + break + except requests.RequestException as error: + last_error = error + if tmp_destination.exists(): + tmp_destination.unlink() + if attempt == max_retries - 1 or not _should_retry_request(error): + raise + time.sleep(DEFAULT_RETRY_BACKOFF_SECONDS**attempt) + else: + raise last_error + + tmp_destination.replace(destination) + return destination + + +def _normalize_remote_path(remote_path): + remote_path = remote_path.strip("/") + if not remote_path: + return "/" + return f"/{remote_path}/" + + +def _materialized_path_to_local_path(materialized_path, destination_root): + relative_path = materialized_path.lstrip("/") + if relative_path.endswith("/"): + relative_path = relative_path[:-1] + return Path(destination_root) / relative_path + def _get_osf_url(filename): osf_id = OSF_DICT[filename] return OSF_URL.format(osf_id) -def _mk_tmpdir(data_dir, file, url): - """Make a temporary directory for fetching.""" - files_pickle = pickle.dumps([(file, url)]) - files_md5 = hashlib.md5(files_pickle).hexdigest() - temp_dir = os.path.join(data_dir, files_md5) +def _get_osf_file_metadata(file_id, timeout=60): + return _request_json(f"{OSF_API_BASE}/files/{file_id}/", timeout=timeout)["data"] - if not os.path.exists(temp_dir): - os.mkdir(temp_dir) - return temp_dir +def _iter_children(node_id=DEFAULT_OSF_NODE, folder_id=None, provider=DEFAULT_PROVIDER, timeout=60): + if folder_id is None: + url = f"{OSF_API_BASE}/nodes/{node_id}/files/{provider}/" + else: + url = f"{OSF_API_BASE}/nodes/{node_id}/files/{provider}/{folder_id}/" + + while url: + payload = _request_json(url, timeout=timeout) + for item in payload["data"]: + yield item + url = payload["links"].get("next") + + +def _get_folder_item(node_id, remote_path, provider=DEFAULT_PROVIDER, timeout=60): + normalized_path = _normalize_remote_path(remote_path) + if normalized_path == "/": + return None + + folder_id = None + parts = [part for part in normalized_path.strip("/").split("/") if part] + materialized_path = "/" + for part in parts: + children = list(_iter_children(node_id=node_id, folder_id=folder_id, provider=provider, timeout=timeout)) + match = None + for child in children: + attrs = child["attributes"] + if attrs["kind"] == "folder" and attrs["name"] == part: + match = child + break + if match is None: + raise FileNotFoundError(f"Remote OSF folder {remote_path!r} was not found in node {node_id}.") + folder_id = match["id"] + materialized_path = match["attributes"]["materialized_path"] + + return { + "id": folder_id, + "materialized_path": materialized_path, + } + + +def _get_remote_item(node_id, remote_path, provider=DEFAULT_PROVIDER, timeout=60): + normalized_path = remote_path.strip("/") + if not normalized_path: + return None + folder_id = None + current_item = None + parts = [part for part in normalized_path.split("/") if part] + for idx, part in enumerate(parts): + children = list(_iter_children(node_id=node_id, folder_id=folder_id, provider=provider, timeout=timeout)) + current_item = None + for child in children: + if child["attributes"]["name"] == part: + current_item = child + break + + if current_item is None: + raise FileNotFoundError(f"Remote OSF path {remote_path!r} was not found in node {node_id}.") + + is_last = idx == len(parts) - 1 + kind = current_item["attributes"]["kind"] + if not is_last: + if kind != "folder": + raise FileNotFoundError( + f"Remote OSF path {remote_path!r} traversed through non-folder component {part!r}." + ) + folder_id = current_item["id"] + + return current_item + + +def list_remote_assets(node_id=DEFAULT_OSF_NODE, remote_path="/", provider=DEFAULT_PROVIDER, timeout=60): + """List files and folders under an OSF path.""" + folder = _get_folder_item(node_id, remote_path, provider=provider, timeout=timeout) + folder_id = None if folder is None else folder["id"] + return list(_iter_children(node_id=node_id, folder_id=folder_id, provider=provider, timeout=timeout)) + + +def download_osf_file( + file_id, + destination_root=".", + overwrite=False, + use_materialized_path=True, + destination=None, + timeout=60, +): + """Download a single OSF file by id.""" + file_data = _get_osf_file_metadata(file_id, timeout=timeout) + attrs = file_data["attributes"] + download_url = file_data["links"]["download"] + + if destination is None: + if use_materialized_path: + destination = _materialized_path_to_local_path(attrs["materialized_path"], destination_root) + else: + destination = Path(destination_root) / attrs["name"] -def _my_fetch_file(data_dir, filename, url, overwrite=False, resume=True, verbose=1): - """Fetch a file from OSF.""" - path_name = os.path.join(data_dir, filename) - if not os.path.exists(path_name) or overwrite: - # Fetch file - tmpdir = _mk_tmpdir(data_dir, filename, url) - temp_fn = fetch_single_file(url, tmpdir, resume=resume, verbose=verbose) + destination = Path(destination) + return _download_to_file(download_url, destination, overwrite=overwrite, timeout=timeout) - # Move and delete tmpdir - shutil.move(temp_fn, path_name) - shutil.rmtree(tmpdir) - return path_name +def download_osf_folder( + remote_path, + destination_root=".", + node_id=DEFAULT_OSF_NODE, + provider=DEFAULT_PROVIDER, + overwrite=False, + timeout=60, +): + """Download all files under a folder path from the published OSF project.""" + folder = _get_folder_item(node_id, remote_path, provider=provider, timeout=timeout) + downloaded = [] + queue = [folder["id"]] + + while queue: + folder_id = queue.pop(0) + for item in _iter_children(node_id=node_id, folder_id=folder_id, provider=provider, timeout=timeout): + attrs = item["attributes"] + if attrs["kind"] == "folder": + queue.append(item["id"]) + continue + + destination = _materialized_path_to_local_path(attrs["materialized_path"], destination_root) + downloaded.append( + _download_to_file( + item["links"]["download"], + destination, + overwrite=overwrite, + timeout=timeout, + ) + ) + + return downloaded + + +def download_osf_path( + remote_path, + destination_root=".", + node_id=DEFAULT_OSF_NODE, + provider=DEFAULT_PROVIDER, + overwrite=False, + timeout=60, +): + """Download a published OSF file or folder by its remote path.""" + item = _get_remote_item(node_id=node_id, remote_path=remote_path, provider=provider, timeout=timeout) + if item["attributes"]["kind"] == "folder": + return download_osf_folder( + remote_path, + destination_root=destination_root, + node_id=node_id, + provider=provider, + overwrite=overwrite, + timeout=timeout, + ) + + destination = _materialized_path_to_local_path(item["attributes"]["materialized_path"], destination_root) + return [ + _download_to_file( + item["links"]["download"], + destination, + overwrite=overwrite, + timeout=timeout, + ) + ] + + +def get_available_assets(): + """Return the names of downloadable assets and bundles.""" + return { + "assets": sorted(OSF_ASSETS), + "bundles": sorted(OSF_BUNDLES), + } + + +def download_asset(name, destination_root=".", overwrite=False, node_id=DEFAULT_OSF_NODE, timeout=60): + """Download a named asset from the built-in manifest.""" + if name not in OSF_ASSETS: + raise KeyError(f"Unknown asset {name!r}. Available assets: {sorted(OSF_ASSETS)}") + + asset = OSF_ASSETS[name] + if asset["type"] == "file": + return [download_osf_file(asset["file_id"], destination_root=destination_root, overwrite=overwrite, timeout=timeout)] + if asset["type"] == "folder": + return download_osf_folder( + asset["remote_path"], + destination_root=destination_root, + node_id=node_id, + overwrite=overwrite, + timeout=timeout, + ) + + raise ValueError(f"Unsupported asset type {asset['type']!r}.") + + +def download_bundle(name, destination_root=".", overwrite=False, node_id=DEFAULT_OSF_NODE, timeout=60): + """Download a predefined bundle of assets.""" + if name not in OSF_BUNDLES: + raise KeyError(f"Unknown bundle {name!r}. Available bundles: {sorted(OSF_BUNDLES)}") + + downloaded = [] + for asset_name in OSF_BUNDLES[name]: + downloaded.extend( + download_asset( + asset_name, + destination_root=destination_root, + overwrite=overwrite, + node_id=node_id, + timeout=timeout, + ) + ) + return downloaded def _fetch_vocabulary( source="neurosynth", - subsample=["Functional"], + subsample=None, data_dir=None, overwrite=False, - resume=True, verbose=1, ): - """Fetch features from OSF. - - Parameters - ---------- - source : :obj:`str` - Name of dataset. - model_nm : :obj:`str` - Name of model. - data_dir : :obj:`pathlib.Path` or :obj:`str`, optional - Path where data should be downloaded. By default, - files are downloaded in home directory - resume : :obj:`bool`, optional - Whether to resume download of a partly-downloaded file. - Default=True. - verbose : :obj:`int`, optional - Verbosity level (0 means no message). - Default=1. - - Returns - ------- - :class:`list` of str - List of feature names. - """ + """Fetch legacy term features/classifications from OSF and return the vocabulary.""" + subsample = ["Functional"] if subsample is None else subsample data_dir = get_data_dir(data_dir) vocabulary_dir = get_data_dir(os.path.join(data_dir, "vocabulary")) filename = f"source-{source}_desc-term_features.csv" - url = _get_osf_url(filename) - - features_fn = _my_fetch_file( - vocabulary_dir, - filename, - url, + features_fn = _download_to_file( + _get_osf_url(filename), + Path(vocabulary_dir) / filename, overwrite=overwrite, - resume=resume, - verbose=verbose, ) + del verbose # preserved for backward compatibility df = pd.read_csv(features_fn) filename_classification = f"source-{source}_desc-term_classification.csv" - url_classification = _get_osf_url(filename_classification) - - classification_fn = _my_fetch_file( - vocabulary_dir, - filename_classification, - url_classification, + classification_fn = _download_to_file( + _get_osf_url(filename_classification), + Path(vocabulary_dir) / filename_classification, overwrite=overwrite, - resume=resume, - verbose=verbose, ) classification_df = pd.read_csv(classification_fn, index_col="Classification") classification = classification_df.index.tolist() - - keep = np.array([c_i for c_i, class_ in enumerate(classification) if class_ in subsample]) + keep = [index for index, class_name in enumerate(classification) if class_name in subsample] return df.values[keep].flatten().tolist() def _get_cogatlas_data(url): try: - # Send a GET request to the API - response = requests.get(url) - - # Raise an exception for bad responses + response = requests.get(url, timeout=60) response.raise_for_status() - - # Parse the JSON response into a Python dictionary tasks = response.json() - - except requests.RequestException as e: - print(f"Error retrieving tasks: {e}") + except requests.RequestException as error: + print(f"Error retrieving tasks: {error}") return None output = {} for task in tasks: - if ("name" in task) and (task["name"] != "") and ("definition_text" in task): + if ("name" in task) and task["name"] and ("definition_text" in task): output[task["name"]] = task["definition_text"] - # elif "name" in task and "definition_text" not in task: - # output[task["name"]] = "" else: print(f"Task {task} does not have a name or definition_text") @@ -156,10 +482,109 @@ def _get_cogatlas_data(url): def get_cogatlas_tasks(): - # API endpoint for tasks + """Fetch task definitions from the Cognitive Atlas API.""" return _get_cogatlas_data("https://www.cognitiveatlas.org/api/v-alpha/task") def get_cogatlas_concepts(): - # API endpoint for concepts + """Fetch concept definitions from the Cognitive Atlas API.""" return _get_cogatlas_data("https://www.cognitiveatlas.org/api/v-alpha/concept") + + +def _get_parser(): + parser = argparse.ArgumentParser(description="Download published braindec assets from OSF") + parser.add_argument( + "--destination_root", + dest="destination_root", + default=".", + help="Root directory under which OSF materialized paths will be recreated.", + ) + parser.add_argument( + "--asset", + dest="assets", + nargs="+", + default=None, + help="One or more named assets to download.", + ) + parser.add_argument( + "--bundle", + dest="bundles", + nargs="+", + default=None, + help="One or more predefined bundles to download.", + ) + parser.add_argument( + "--folder", + dest="folders", + nargs="+", + default=None, + help="One or more raw OSF folder paths to download, for example data/cognitive_atlas.", + ) + parser.add_argument( + "--list", + dest="list_only", + action="store_true", + help="Print available built-in assets and bundles.", + ) + parser.add_argument( + "--overwrite", + dest="overwrite", + action="store_true", + help="Overwrite existing local files.", + ) + return parser + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + + if options.list_only: + available = get_available_assets() + print("Assets:") + for asset in available["assets"]: + print(f" - {asset}") + print("Bundles:") + for bundle in available["bundles"]: + print(f" - {bundle}") + return + + downloaded = [] + if options.assets: + for asset in options.assets: + downloaded.extend( + download_asset( + asset, + destination_root=options.destination_root, + overwrite=options.overwrite, + ) + ) + + if options.bundles: + for bundle in options.bundles: + downloaded.extend( + download_bundle( + bundle, + destination_root=options.destination_root, + overwrite=options.overwrite, + ) + ) + + if options.folders: + for folder in options.folders: + downloaded.extend( + download_osf_folder( + folder, + destination_root=options.destination_root, + overwrite=options.overwrite, + ) + ) + + if not (options.assets or options.bundles or options.folders or options.list_only): + raise SystemExit("Select at least one of --asset, --bundle, --folder, or --list.") + + for path in downloaded: + print(path) + + +if __name__ == "__main__": + _main() diff --git a/braindec/predict.py b/braindec/predict.py index 3598fcf..b3a4682 100644 --- a/braindec/predict.py +++ b/braindec/predict.py @@ -7,16 +7,22 @@ import numpy as np import pandas as pd import torch -from nilearn._utils.niimg_conversions import check_same_fov from nilearn.image import load_img, resample_to_img from braindec.cogatlas import CognitiveAtlas from braindec.embedding import ImageEmbedding from braindec.model import build_model -from braindec.utils import _get_device, _read_vocabulary, get_data_dir +from braindec.utils import _get_device, _read_vocabulary, get_data_dir, images_have_same_fov -def preprocess_image(image, standardize=False, data_dir=None, space="MNI152", density=None): +def preprocess_image( + image, + standardize=False, + data_dir=None, + space="MNI152", + density=None, + image_emb_gene=None, +): """ Preprocess the image. @@ -26,12 +32,13 @@ def preprocess_image(image, standardize=False, data_dir=None, space="MNI152", de data_dir = get_data_dir(data_dir) nilearn_dir = op.join(data_dir, "nilearn") - image_emb_gene = ImageEmbedding( - standardize=standardize, - nilearn_dir=nilearn_dir, - space=space, - density=density, - ) + if image_emb_gene is None: + image_emb_gene = ImageEmbedding( + standardize=standardize, + nilearn_dir=nilearn_dir, + space=space, + density=density, + ) image_embedding_arr = image_emb_gene(image) return torch.from_numpy(image_embedding_arr).float() @@ -47,6 +54,7 @@ def image_to_labels( logit_scale=None, return_posterior_probability=False, device=None, + model=None, **kwargs, ): """Predict the labels of an image using a pre-trained model.""" @@ -65,7 +73,7 @@ def image_to_labels( image_input = image_input / (image_input.norm(dim=-1, keepdim=True) + 1e-8) # Calculate features - model = build_model(model_path, device=device) + model = build_model(model_path, device=device) if model is None else model with torch.no_grad(): image_features, text_features = model(image_input, text_inputs) # normalized @@ -111,7 +119,9 @@ def image_to_labels( "bayes_factor": bayes_factor[top_indices], } ) - return task_prob_df, posterior_probability if return_posterior_probability else task_prob_df + if return_posterior_probability: + return task_prob_df, posterior_probability + return task_prob_df def image_to_labels_hierarchical( @@ -124,6 +134,7 @@ def image_to_labels_hierarchical( topk=10, logit_scale=None, device=None, + model=None, **kwargs, ): """Predict the label of an image.""" @@ -137,6 +148,7 @@ def image_to_labels_hierarchical( logit_scale=logit_scale, return_posterior_probability=True, device=device, + model=model, **kwargs, ) @@ -298,7 +310,7 @@ def _main(argv=None): img = nib.load(image_fn) mask_img = nib.load(mask_fn) - if not check_same_fov(img, reference_masker=mask_img): + if not images_have_same_fov(img, mask_img): img = resample_to_img(img, mask_img) vocabulary, vocabulary_emb, vocabulary_prior = _read_vocabulary( diff --git a/braindec/utils.py b/braindec/utils.py index 47e3982..7510601 100644 --- a/braindec/utils.py +++ b/braindec/utils.py @@ -7,10 +7,7 @@ import numpy as np import pandas as pd import torch -from neuromaps import transforms -from neuromaps.datasets import fetch_atlas from nibabel.gifti import GiftiDataArray -from nilearn.surface import PolyMesh, SurfaceImage, load_surf_mesh from nimare.utils import get_resource_path from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer @@ -98,6 +95,11 @@ def _get_device(): return torch.device("cpu") # Default to CPU +def images_have_same_fov(img, reference_img): + """Return whether two Niimg-like objects share shape and affine.""" + return img.shape[:3] == reference_img.shape[:3] and np.allclose(img.affine, reference_img.affine) + + def _zero_medial_wall( data_lh, data_rh, @@ -107,6 +109,8 @@ def _zero_medial_wall( neuromaps_dir=None, ): """Remove medial wall from data in fsLR space.""" + from neuromaps.datasets import fetch_atlas + atlas = fetch_atlas(space, density, data_dir=neuromaps_dir, verbose=0) medial_lh, medial_rh = atlas["medial"] @@ -175,6 +179,8 @@ def _rm_medial_wall( `data` has the incorrect number of vertices (59412 or 64984 only accepted) """ + from neuromaps.datasets import fetch_atlas + assert data_lh.shape[0] == N_VERTICES_PH[space][density] assert data_rh.shape[0] == N_VERTICES_PH[space][density] @@ -205,6 +211,8 @@ def _vol_to_surf( neuromaps_dir=None, ): """Transform 4D metamaps from volume to surface space.""" + from neuromaps import transforms + if space == "fsLR": metamap_lh, metamap_rh = transforms.mni152_to_fslr(metamap, fslr_density=density) elif space == "fsaverage": @@ -245,6 +253,8 @@ def _vol_surfimg( density="32k", neuromaps_dir=None, ): + from nilearn.surface import PolyMesh, SurfaceImage, load_surf_mesh + lh_data, rh_data, atlas = _vol_to_surf( vol, space=space, diff --git a/docs/_static/.gitkeep b/docs/_static/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/docs/_static/.gitkeep @@ -0,0 +1 @@ + diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 0000000..d7b54a1 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,40 @@ +API Reference +============= + +Core prediction +--------------- + +.. automodule:: braindec.predict + :members: image_to_labels, image_to_labels_hierarchical, preprocess_image + +Embeddings +---------- + +.. automodule:: braindec.embedding + :members: ImageEmbedding, TextEmbedding + +Cognitive Atlas +--------------- + +.. automodule:: braindec.cogatlas + :members: CognitiveAtlas + +Data fetching +------------- + +.. automodule:: braindec.fetcher + :members: download_bundle, download_asset, download_osf_file, + download_osf_folder, get_available_assets, + get_cogatlas_tasks, get_cogatlas_concepts + +Model +----- + +.. automodule:: braindec.model + :members: CLIP, build_model + +Utilities +--------- + +.. automodule:: braindec.utils + :members: get_data_dir, images_have_same_fov diff --git a/docs/auto_examples/02_niclip_demo.ipynb b/docs/auto_examples/02_niclip_demo.ipynb new file mode 100644 index 0000000..b95e364 --- /dev/null +++ b/docs/auto_examples/02_niclip_demo.ipynb @@ -0,0 +1,971 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "61da1140", + "metadata": {}, + "outputs": [], + "source": [ + "# Install braindec (this cell is only needed on Google Colab).\n", + "%pip install \"braindec[plotting] @ git+https://github.com/jdkent/brain-decoder.git\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56e5f3c3", + "metadata": {}, + "outputs": [], + "source": [ + "# Display Matplotlib figures inline in notebooks.\n", + "%matplotlib inline\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9822f39f", + "metadata": {}, + "outputs": [], + "source": [ + "r\"\"\"\n", + "NiCLIP: Functional Brain Decoding Tutorial\n", + "===========================================\n", + "\n", + "`NiCLIP `_ is a contrastive\n", + "language–image pre-training (CLIP) model trained on ~23,000 neuroimaging\n", + "articles that maps brain activation patterns to cognitive task descriptions\n", + "from the `Cognitive Atlas `_ ontology.\n", + "\n", + "This tutorial walks through the main use cases:\n", + "\n", + "1. **Group-level task decoding** — predict tasks, concepts, and cognitive\n", + " process domains from a group-level activation map.\n", + "2. **Hierarchical decoding** — obtain predictions at three ontology levels\n", + " (tasks → concepts → domains) using the noisy-OR propagation rule.\n", + "3. **Brain region characterization** — characterize anatomical ROIs\n", + " without pre-computed meta-analytic maps.\n", + "4. **Subject-level decoding** — apply decoding to noisier single-subject maps.\n", + "5. **Custom vocabulary** — decode against a user-supplied task vocabulary.\n", + "6. **Latent space exploration** — visualize the shared image–text embedding\n", + " space learned by NiCLIP.\n", + "\n", + ".. note::\n", + "\n", + " **Run this tutorial on Google Colab**\n", + "\n", + " .. image:: https://colab.research.google.com/assets/colab-badge.svg\n", + " :target: https://colab.research.google.com/github/jdkent/brain-decoder/blob/main/docs/auto_examples/02_niclip_demo.ipynb\n", + " :alt: Open In Colab\n", + "\n", + " The first notebook cell installs :code:`braindec` and its dependencies\n", + " automatically. The full install takes a few minutes on a fresh Colab\n", + " runtime; subsequent runs reuse the cached packages.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8770834", + "metadata": {}, + "outputs": [], + "source": [ + "# Section 0: Download example assets\n", + "# ------------------------------------\n", + "# NiCLIP ships a curated set of publishable assets on\n", + "# `OSF `_. The ``example_prediction`` bundle\n", + "# contains the pre-trained CLIP model, reduced Cognitive Atlas vocabulary,\n", + "# pre-computed vocabulary embeddings, vocabulary prior, brain mask, and\n", + "# Cognitive Atlas ontology snapshots — everything required to run decoding.\n", + "#\n", + "# The download is skipped automatically for files that already exist locally.\n", + "\n", + "import os\n", + "import os.path as op\n", + "from pathlib import Path\n", + "\n", + "import nibabel as nib\n", + "import numpy as np\n", + "import pandas as pd\n", + "import requests\n", + "\n", + "from braindec.fetcher import download_bundle, get_data_dir\n", + "\n", + "work_dir = get_data_dir()\n", + "print(f\"Working directory: {work_dir}\")\n", + "\n", + "downloaded = download_bundle(\"example_prediction\", destination_root=work_dir)\n", + "print(f\"Bundle contains {len(downloaded)} files\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dcaf38aa", + "metadata": {}, + "outputs": [], + "source": [ + "# Construct paths to downloaded assets.\n", + "# These mirror the OSF folder structure that ``download_bundle`` preserves.\n", + "\n", + "MODEL_NAME = \"BrainGPT-7B-v0.2\"\n", + "SECTION = \"body\"\n", + "SOURCE = \"cogatlasred\"\n", + "VOC_LABEL = f\"vocabulary-{SOURCE}_task-combined_embedding-{MODEL_NAME}\"\n", + "\n", + "data_dir = op.join(work_dir, \"data\")\n", + "results_dir = op.join(work_dir, \"results\")\n", + "voc_dir = op.join(data_dir, \"vocabulary\")\n", + "cog_atlas_dir = op.join(data_dir, \"cognitive_atlas\")\n", + "\n", + "model_fn = op.join(results_dir, \"pubmed\",\n", + " f\"model-clip_section-{SECTION}_embedding-{MODEL_NAME}_best.pth\")\n", + "vocabulary_fn = op.join(voc_dir, f\"vocabulary-{SOURCE}_task.txt\")\n", + "vocabulary_emb_fn = op.join(voc_dir, f\"{VOC_LABEL}.npy\")\n", + "vocabulary_prior_fn = op.join(voc_dir, f\"{VOC_LABEL}_section-{SECTION}_prior.npy\")\n", + "mask_fn = op.join(data_dir, \"MNI152_2x2x2_brainmask.nii.gz\")\n", + "\n", + "for label, path in [\n", + " (\"model\", model_fn),\n", + " (\"vocabulary\", vocabulary_fn),\n", + " (\"vocabulary embeddings\", vocabulary_emb_fn),\n", + " (\"vocabulary prior\", vocabulary_prior_fn),\n", + " (\"brain mask\", mask_fn),\n", + " (\"cognitive atlas\", cog_atlas_dir),\n", + "]:\n", + " status = \"✓\" if op.exists(path) else \"✗ MISSING\"\n", + " print(f\" {status} {label}: {path}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "847dfe91", + "metadata": {}, + "outputs": [], + "source": [ + "# Download representative HCP group-level contrast maps from\n", + "# `NeuroVault `_ (public access,\n", + "# no account required). We use three contrasts that span different cognitive\n", + "# domains to illustrate decoding across sections 1–6.\n", + "\n", + "HCP_COLLECTION_ID = 457\n", + "HCP_MAPS = {\n", + " \"motor\": \"tfMRI_MOTOR_AVG_zstat1.nii.gz\",\n", + " \"language\": \"tfMRI_LANGUAGE_STORY-MATH_zstat1.nii.gz\",\n", + " \"emotion\": \"tfMRI_EMOTION_FACES-SHAPES_zstat1.nii.gz\",\n", + " \"working_memory\": \"tfMRI_WM_2BK-0BK_zstat1.nii.gz\",\n", + "}\n", + "\n", + "hcp_dir = Path(data_dir) / \"hcp\" / \"neurovault\"\n", + "hcp_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "hcp_paths = {}\n", + "for domain, filename in HCP_MAPS.items():\n", + " dest = hcp_dir / filename\n", + " if not dest.exists():\n", + " url = f\"https://neurovault.org/media/images/{HCP_COLLECTION_ID}/{filename}\"\n", + " print(f\"Downloading {domain} map …\")\n", + " with requests.get(url, stream=True, timeout=120) as r:\n", + " r.raise_for_status()\n", + " with open(dest, \"wb\") as fh:\n", + " for chunk in r.iter_content(chunk_size=1024 * 1024):\n", + " fh.write(chunk)\n", + " hcp_paths[domain] = str(dest)\n", + " print(f\" {domain}: {dest}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1130d6eb", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialise shared resources that will be reused across sections.\n", + "# Building the model and image embedder once avoids repeated I/O and\n", + "# DiFuMo atlas downloads.\n", + "\n", + "import torch\n", + "from braindec.cogatlas import CognitiveAtlas\n", + "from braindec.embedding import ImageEmbedding\n", + "from braindec.model import build_model\n", + "from braindec.utils import _get_device\n", + "\n", + "device = _get_device()\n", + "print(f\"Using device: {device}\")\n", + "\n", + "model = build_model(model_fn, device=device)\n", + "\n", + "image_embedder = ImageEmbedding(\n", + " standardize=False,\n", + " nilearn_dir=op.join(data_dir, \"nilearn\"),\n", + " space=\"MNI152\",\n", + ")\n", + "\n", + "with open(vocabulary_fn) as fh:\n", + " vocabulary = [line.strip() for line in fh]\n", + "vocabulary_emb = np.load(vocabulary_emb_fn)\n", + "vocabulary_prior = np.load(vocabulary_prior_fn)\n", + "\n", + "print(f\"Vocabulary size: {len(vocabulary)} tasks\")\n", + "print(f\"Embedding shape: {vocabulary_emb.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e69948b", + "metadata": {}, + "outputs": [], + "source": [ + "# Section 1: Group-level task decoding\n", + "# ----------------------------------------\n", + "# The primary NiCLIP use case is *functional decoding*: given a brain\n", + "# activation map, retrieve the most likely cognitive tasks from the Cognitive\n", + "# Atlas. NiCLIP computes posterior probabilities P(T|A) using Bayes'\n", + "# theorem over the CLIP cosine similarities.\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from nilearn.plotting import plot_stat_map\n", + "\n", + "from braindec.predict import image_to_labels\n", + "\n", + "motor_img = nib.load(hcp_paths[\"motor\"])\n", + "\n", + "task_df = image_to_labels(\n", + " motor_img,\n", + " model_path=model_fn,\n", + " vocabulary=vocabulary,\n", + " vocabulary_emb=vocabulary_emb,\n", + " prior_probability=vocabulary_prior,\n", + " topk=10,\n", + " logit_scale=20.0,\n", + " model=model,\n", + " image_emb_gene=image_embedder,\n", + " data_dir=data_dir,\n", + ")\n", + "\n", + "print(\"Top-10 task predictions for HCP Motor (AVG) contrast:\")\n", + "print(task_df.to_string(index=False))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3565bb9e", + "metadata": {}, + "outputs": [], + "source": [ + "# Visualise the input activation map and the top-5 task predictions.\n", + "\n", + "plot_stat_map(\n", + " motor_img,\n", + " display_mode=\"z\",\n", + " cut_coords=5,\n", + " colorbar=True,\n", + " threshold=2.0,\n", + " title=\"HCP Motor (AVG) z-stat\",\n", + ")\n", + "plt.show()\n", + "\n", + "top5 = task_df.head(5)\n", + "short_labels = [t[:40] + \"…\" if len(t) > 40 else t for t in top5[\"pred\"]]\n", + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "ax.barh(range(len(top5)), top5[\"prob\"], color=\"steelblue\")\n", + "ax.set_yticks(range(len(top5)))\n", + "ax.set_yticklabels(short_labels, fontsize=9)\n", + "ax.invert_yaxis()\n", + "ax.set_xlabel(\"Posterior probability P(T|A)\")\n", + "ax.set_title(\"Top-5 task predictions\")\n", + "ax.set_xlim(0, top5[\"prob\"].max() * 1.2)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df6f5531", + "metadata": {}, + "outputs": [], + "source": [ + "# Section 2: Hierarchical decoding\n", + "# -------------------------------------\n", + "# NiCLIP propagates task posteriors up the Cognitive Atlas ontology using\n", + "# a noisy-OR model to derive concept and cognitive process domain\n", + "# probabilities: P(C|A) and P(D|A). This produces interpretations at\n", + "# three levels of specificity.\n", + "\n", + "import json\n", + "\n", + "from braindec.predict import image_to_labels_hierarchical\n", + "\n", + "concept_to_process_fn = op.join(cog_atlas_dir, \"concept_to_process.json\")\n", + "with open(concept_to_process_fn) as fh:\n", + " concept_to_process = json.load(fh)\n", + "\n", + "reduced_tasks_df = pd.read_csv(op.join(cog_atlas_dir, \"reduced_tasks.csv\"))\n", + "\n", + "cog_atlas = CognitiveAtlas(\n", + " data_dir=data_dir,\n", + " task_snapshot=op.join(cog_atlas_dir, \"task_snapshot-02-19-25.json\"),\n", + " concept_snapshot=op.join(cog_atlas_dir, \"concept_extended_snapshot-02-19-25.json\"),\n", + " concept_to_process=concept_to_process,\n", + " reduced_tasks=reduced_tasks_df,\n", + ")\n", + "\n", + "print(f\"Cognitive Atlas: {len(cog_atlas.task_names)} tasks | \"\n", + " f\"{len(cog_atlas.concept_names)} concepts | \"\n", + " f\"{len(cog_atlas.process_names)} domains\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac644cf8", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "task_df_h, concept_df, domain_df = image_to_labels_hierarchical(\n", + " motor_img,\n", + " model_path=model_fn,\n", + " vocabulary=vocabulary,\n", + " vocabulary_emb=vocabulary_emb,\n", + " prior_probability=vocabulary_prior,\n", + " cognitiveatlas=cog_atlas,\n", + " topk=5,\n", + " logit_scale=20.0,\n", + " model=model,\n", + " image_emb_gene=image_embedder,\n", + " data_dir=data_dir,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d485260a", + "metadata": {}, + "outputs": [], + "source": [ + "# Display predictions at all three ontology levels.\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(16, 4))\n", + "panels = [\n", + " (task_df_h, \"Tasks P(T|A)\", \"prob\"),\n", + " (concept_df, \"Concepts P(C|A)\", \"prob\"),\n", + " (domain_df, \"Domains P(D|A)\", \"prob\"),\n", + "]\n", + "\n", + "for ax, (df, title, col) in zip(axes, panels):\n", + " df_top = df.head(5)\n", + " labels = [t[:35] + \"…\" if len(t) > 35 else t for t in df_top[\"pred\"]]\n", + " ax.barh(range(len(df_top)), df_top[col], color=\"steelblue\")\n", + " ax.set_yticks(range(len(df_top)))\n", + " ax.set_yticklabels(labels, fontsize=8)\n", + " ax.invert_yaxis()\n", + " ax.set_xlabel(\"Posterior probability\")\n", + " ax.set_title(title, fontsize=10)\n", + " ax.set_xlim(0, df_top[col].max() * 1.3)\n", + "\n", + "fig.suptitle(\"HCP Motor (AVG) — Hierarchical decoding\", fontsize=12)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de188bc3", + "metadata": {}, + "outputs": [], + "source": [ + "# Section 3: Brain region characterization\n", + "# -------------------------------------------\n", + "# Instead of a task activation map, NiCLIP can decode *anatomical ROIs*\n", + "# directly — enabling functional characterisation of brain regions without\n", + "# requiring pre-computed meta-analytic maps. This supports hypothesis\n", + "# generation about the cognitive roles of unstudied regions.\n", + "#\n", + "# We create binary ROI masks from the Harvard-Oxford atlas available via\n", + "# nilearn. The subcortical atlas provides named amygdala, hippocampus, and\n", + "# striatal regions. For cortical examples not represented in this atlas, we\n", + "# add simple spherical MNI ROIs so the tutorial still covers a broader set of\n", + "# functional regions.\n", + "\n", + "from nilearn import datasets, image as nli_image\n", + "\n", + "mask_img = nib.load(mask_fn)\n", + "mask_data = mask_img.get_fdata() > 0\n", + "\n", + "\n", + "def make_label_roi(atlas_img, atlas_labels, region_names):\n", + " \"\"\"Create a binary ROI from one or more deterministic atlas label names.\"\"\"\n", + " atlas_data = atlas_img.get_fdata()\n", + " label_to_idx = {name: idx for idx, name in enumerate(atlas_labels)}\n", + " roi = np.zeros(atlas_data.shape, dtype=np.float32)\n", + " for name in region_names:\n", + " if name in label_to_idx:\n", + " roi[atlas_data == label_to_idx[name]] = 1.0\n", + " else:\n", + " print(f\" Warning: '{name}' not found in atlas.\")\n", + " return nib.Nifti1Image(roi, atlas_img.affine, atlas_img.header)\n", + "\n", + "\n", + "def make_spherical_roi(center_xyz, radius_mm=8):\n", + " \"\"\"Create a spherical ROI in MNI millimeter coordinates.\"\"\"\n", + " ijk = np.indices(mask_img.shape).reshape(3, -1).T\n", + " xyz = nib.affines.apply_affine(mask_img.affine, ijk)\n", + " distances = np.linalg.norm(xyz - np.asarray(center_xyz), axis=1)\n", + " roi = (distances <= radius_mm).reshape(mask_img.shape) & mask_data\n", + " return nib.Nifti1Image(roi.astype(np.float32), mask_img.affine, mask_img.header)\n", + "\n", + "\n", + "ROI_LABEL_SPECS = {\n", + " \"Amygdala\": [\"Left Amygdala\", \"Right Amygdala\"],\n", + " \"Hippocampus\": [\"Left Hippocampus\", \"Right Hippocampus\"],\n", + " \"Striatum\": [\"Left Putamen\", \"Right Putamen\", \"Left Caudate\", \"Right Caudate\"],\n", + "}\n", + "\n", + "ROI_COORD_SPECS = {\n", + " \"Amygdala\": (-22, -4, -18),\n", + " \"Hippocampus\": (-26, -20, -14),\n", + " \"Insula\": (-34, 18, 4),\n", + " \"Striatum\": (-18, 8, 4),\n", + " \"rTPJ\": (54, -54, 24),\n", + " \"vmPFC\": (0, 46, -8),\n", + "}\n", + "\n", + "try:\n", + " ho_sub = datasets.fetch_atlas_harvard_oxford(\n", + " \"sub-maxprob-thr25-2mm\",\n", + " data_dir=op.join(data_dir, \"nilearn\"),\n", + " verbose=1,\n", + " )\n", + " roi_images = {\n", + " name: make_label_roi(ho_sub.maps, ho_sub.labels, labels)\n", + " for name, labels in ROI_LABEL_SPECS.items()\n", + " }\n", + " roi_images.update({\n", + " name: make_spherical_roi(center)\n", + " for name, center in ROI_COORD_SPECS.items()\n", + " if name not in roi_images\n", + " })\n", + " print(\"Using Harvard-Oxford atlas ROIs with spherical cortical examples.\")\n", + "except Exception as exc:\n", + " print(f\"Harvard-Oxford atlas download failed ({type(exc).__name__}: {exc})\")\n", + " print(\"Using fallback spherical MNI ROIs for this tutorial run.\")\n", + " roi_images = {name: make_spherical_roi(center) for name, center in ROI_COORD_SPECS.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7afabaf", + "metadata": {}, + "outputs": [], + "source": [ + "# Decode each ROI and collect the top task, concept, and domain.\n", + "\n", + "roi_summary = []\n", + "for roi_name, roi_img in roi_images.items():\n", + " t_df, c_df, d_df = image_to_labels_hierarchical(\n", + " roi_img,\n", + " model_path=model_fn,\n", + " vocabulary=vocabulary,\n", + " vocabulary_emb=vocabulary_emb,\n", + " prior_probability=vocabulary_prior,\n", + " cognitiveatlas=cog_atlas,\n", + " topk=3,\n", + " logit_scale=20.0,\n", + " model=model,\n", + " image_emb_gene=image_embedder,\n", + " data_dir=data_dir,\n", + " )\n", + " roi_summary.append({\n", + " \"ROI\": roi_name,\n", + " \"Top task\": t_df.iloc[0][\"pred\"],\n", + " \"Task P(T|A)\": f\"{t_df.iloc[0]['prob']:.3f}\",\n", + " \"Top concept\": c_df.iloc[0][\"pred\"],\n", + " \"Top domain\": d_df.iloc[0][\"pred\"],\n", + " })\n", + "\n", + "summary_df = pd.DataFrame(roi_summary)\n", + "print(summary_df.to_string(index=False))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbab8584", + "metadata": {}, + "outputs": [], + "source": [ + "# Visualise one ROI alongside its top prediction.\n", + "\n", + "from nilearn.plotting import plot_roi\n", + "\n", + "plot_roi(\n", + " roi_images[\"Amygdala\"],\n", + " title=\"Amygdala (bilateral)\",\n", + " display_mode=\"ortho\",\n", + " cut_coords=(0, -4, -18),\n", + " colorbar=False,\n", + ")\n", + "plt.show()\n", + "\n", + "# Decode the amygdala with a finer top-k for the bar chart.\n", + "t_df, c_df, d_df = image_to_labels_hierarchical(\n", + " roi_images[\"Amygdala\"],\n", + " model_path=model_fn,\n", + " vocabulary=vocabulary,\n", + " vocabulary_emb=vocabulary_emb,\n", + " prior_probability=vocabulary_prior,\n", + " cognitiveatlas=cog_atlas,\n", + " topk=5,\n", + " logit_scale=20.0,\n", + " model=model,\n", + " image_emb_gene=image_embedder,\n", + " data_dir=data_dir,\n", + ")\n", + "\n", + "labels = [t[:38] + \"…\" if len(t) > 38 else t for t in t_df[\"pred\"]]\n", + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "ax.barh(range(5), t_df[\"prob\"], color=\"salmon\")\n", + "ax.set_yticks(range(5))\n", + "ax.set_yticklabels(labels, fontsize=8)\n", + "ax.invert_yaxis()\n", + "ax.set_xlabel(\"P(T|A)\")\n", + "ax.set_title(\"Top-5 task predictions for Amygdala ROI\")\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b40fcb5", + "metadata": {}, + "outputs": [], + "source": [ + "# Section 4: Subject-level decoding\n", + "# -------------------------------------\n", + "# NiCLIP can decode single-subject activation maps, though performance is\n", + "# lower than group-level due to higher noise. Here we simulate a\n", + "# subject-level map by adding Gaussian noise to the group-level motor\n", + "# contrast, then compare predicted ranks to the clean result.\n", + "#\n", + "# In practice, you would supply your own subject-level t-stat or z-stat\n", + "# NIfTI image in place of the simulated map below.\n", + "\n", + "motor_data = motor_img.get_fdata()\n", + "rng = np.random.default_rng(42)\n", + "noise_std = motor_data.std()\n", + "noisy_data = motor_data + rng.normal(scale=noise_std, size=motor_data.shape)\n", + "noisy_img = nib.Nifti1Image(noisy_data, motor_img.affine, motor_img.header)\n", + "\n", + "task_df_noisy = image_to_labels(\n", + " noisy_img,\n", + " model_path=model_fn,\n", + " vocabulary=vocabulary,\n", + " vocabulary_emb=vocabulary_emb,\n", + " prior_probability=vocabulary_prior,\n", + " topk=10,\n", + " logit_scale=20.0,\n", + " model=model,\n", + " image_emb_gene=image_embedder,\n", + " data_dir=data_dir,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12b5cd40", + "metadata": {}, + "outputs": [], + "source": [ + "# Compare predictions from the clean group map vs. the simulated subject map.\n", + "\n", + "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", + "for ax, (df, title, color) in zip(\n", + " axes,\n", + " [\n", + " (task_df.head(5), \"Group-level (clean)\", \"steelblue\"),\n", + " (task_df_noisy.head(5), \"Subject-level (simulated noise)\", \"orange\"),\n", + " ],\n", + "):\n", + " labels = [t[:40] + \"…\" if len(t) > 40 else t for t in df[\"pred\"]]\n", + " ax.barh(range(len(df)), df[\"prob\"], color=color)\n", + " ax.set_yticks(range(len(df)))\n", + " ax.set_yticklabels(labels, fontsize=8)\n", + " ax.invert_yaxis()\n", + " ax.set_xlabel(\"P(T|A)\")\n", + " ax.set_title(title)\n", + "\n", + "fig.suptitle(\"Motor decoding: group vs. subject-level noise\", fontsize=12)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2b2e1fc", + "metadata": {}, + "outputs": [], + "source": [ + "# Section 5: Custom vocabulary decoding\n", + "# -----------------------------------------\n", + "# NiCLIP accepts any list of task names paired with their LLM-derived text\n", + "# embeddings. This lets you decode against a domain-specific vocabulary\n", + "# instead of (or in addition to) the full Cognitive Atlas.\n", + "#\n", + "# **Two workflows:**\n", + "#\n", + "# *Workflow A — subset the existing vocabulary.*\n", + "# Select a subset of Cognitive Atlas tasks relevant to your study domain\n", + "# and decode with just those terms. No additional embedding needed.\n", + "#\n", + "# *Workflow B — embed entirely new task names.*\n", + "# Use :class:`~braindec.embedding.TextEmbedding` with BrainGPT to embed\n", + "# custom task descriptions, then pass them directly to\n", + "# :func:`~braindec.predict.image_to_labels`.\n", + "\n", + "# Workflow A: emotion-focused vocabulary subset\n", + "EMOTION_KEYWORDS = [\"emotion\", \"fear\", \"affect\", \"face\", \"amygdala\", \"valence\", \"threat\"]\n", + "\n", + "custom_idx = [\n", + " i for i, task in enumerate(vocabulary)\n", + " if any(kw in task.lower() for kw in EMOTION_KEYWORDS)\n", + "]\n", + "custom_vocabulary = [vocabulary[i] for i in custom_idx]\n", + "custom_vocabulary_emb = vocabulary_emb[custom_idx]\n", + "custom_prior = vocabulary_prior[custom_idx]\n", + "# Re-normalise prior so probabilities sum to 1.\n", + "custom_prior = custom_prior / custom_prior.sum()\n", + "\n", + "print(f\"Custom emotion vocabulary: {len(custom_vocabulary)} tasks\")\n", + "print(\" \" + \"\\n \".join(custom_vocabulary[:8]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7a04e00", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "emotion_img = nib.load(hcp_paths[\"emotion\"])\n", + "\n", + "task_df_custom = image_to_labels(\n", + " emotion_img,\n", + " model_path=model_fn,\n", + " vocabulary=custom_vocabulary,\n", + " vocabulary_emb=custom_vocabulary_emb,\n", + " prior_probability=custom_prior,\n", + " topk=min(8, len(custom_vocabulary)),\n", + " logit_scale=20.0,\n", + " model=model,\n", + " image_emb_gene=image_embedder,\n", + " data_dir=data_dir,\n", + ")\n", + "\n", + "plot_stat_map(\n", + " emotion_img,\n", + " display_mode=\"z\",\n", + " cut_coords=5,\n", + " threshold=2.0,\n", + " title=\"HCP Emotion (Faces vs Shapes)\",\n", + ")\n", + "plt.show()\n", + "\n", + "labels = [t[:38] + \"…\" if len(t) > 38 else t for t in task_df_custom[\"pred\"]]\n", + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "ax.barh(range(len(task_df_custom)), task_df_custom[\"prob\"], color=\"mediumpurple\")\n", + "ax.set_yticks(range(len(task_df_custom)))\n", + "ax.set_yticklabels(labels, fontsize=8)\n", + "ax.invert_yaxis()\n", + "ax.set_xlabel(\"P(T|A)\")\n", + "ax.set_title(\"Emotion-focused vocabulary predictions\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0eb8c63", + "metadata": {}, + "outputs": [], + "source": [ + "# **Workflow B — embedding truly new task names (GPU required).**\n", + "#\n", + "# If you have task descriptions not present in the Cognitive Atlas, embed\n", + "# them with :class:`~braindec.embedding.TextEmbedding` and build a prior\n", + "# from uniform weights. The code below is shown for reference; a GPU with\n", + "# ≥14 GB VRAM (e.g., A100) is needed to run BrainGPT-7B.\n", + "\n", + "# .. code-block:: python\n", + "#\n", + "# from braindec.embedding import TextEmbedding\n", + "#\n", + "# my_tasks = [\n", + "# \"emotional conflict task\",\n", + "# \"social exclusion paradigm\",\n", + "# \"fear extinction training\",\n", + "# ]\n", + "#\n", + "# text_embedder = TextEmbedding(\n", + "# model_name=\"BrainGPT/BrainGPT-7B-v0.2\",\n", + "# batch_size=1,\n", + "# )\n", + "# my_vocabulary_emb = text_embedder(my_tasks) # shape (n_tasks, embedding_dim)\n", + "# my_prior = np.full(len(my_tasks), 1.0 / len(my_tasks))\n", + "#\n", + "# task_df_new = image_to_labels(\n", + "# emotion_img,\n", + "# model_path=model_fn,\n", + "# vocabulary=my_tasks,\n", + "# vocabulary_emb=my_vocabulary_emb,\n", + "# prior_probability=my_prior,\n", + "# topk=len(my_tasks),\n", + "# logit_scale=20.0,\n", + "# model=model,\n", + "# image_emb_gene=image_embedder,\n", + "# data_dir=data_dir,\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "abfba100", + "metadata": {}, + "outputs": [], + "source": [ + "# Section 6: Latent space exploration\n", + "# ----------------------------------------\n", + "# NiCLIP learns a shared image–text embedding space. Here we visualise:\n", + "#\n", + "# * The DiFuMo-512 parcellation atlas used to project activation maps.\n", + "# * Cosine similarity between embedded HCP contrasts (image–image).\n", + "# * Cosine similarity between embedded HCP contrasts and vocabulary\n", + "# terms (image–text)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4743b03d", + "metadata": {}, + "outputs": [], + "source": [ + "# **6a. DiFuMo-512 parcellation atlas**\n", + "#\n", + "# NiCLIP compresses each activation map to a 512-dimensional vector using\n", + "# the `DiFuMo atlas `_ before\n", + "# passing it through the CLIP image encoder.\n", + "\n", + "from nilearn import datasets as nl_datasets\n", + "from nilearn.plotting import plot_roi\n", + "\n", + "difumo_kwargs = dict(dimension=512, resolution_mm=2,\n", + " data_dir=op.join(data_dir, \"nilearn\"))\n", + "try:\n", + " difumo = nl_datasets.fetch_atlas_difumo(legacy_format=False, **difumo_kwargs)\n", + "except TypeError:\n", + " difumo = nl_datasets.fetch_atlas_difumo(**difumo_kwargs)\n", + "\n", + "# Show a handful of DiFuMo components to illustrate the parcellation.\n", + "difumo_img = nib.load(difumo.maps)\n", + "n_components_to_show = 6\n", + "for comp_i in range(n_components_to_show):\n", + " comp_img = nli_image.index_img(difumo_img, comp_i)\n", + " plot_roi(\n", + " comp_img,\n", + " display_mode=\"z\",\n", + " cut_coords=1,\n", + " title=f\"DiFuMo component {comp_i + 1}\",\n", + " colorbar=False,\n", + " )\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14ddc5d7", + "metadata": {}, + "outputs": [], + "source": [ + "# **6b. Image–image cosine similarity across HCP contrasts**\n", + "#\n", + "# Embed all four downloaded HCP contrasts and measure how similar they are\n", + "# to each other in the shared CLIP latent space. Semantically related\n", + "# contrasts (e.g., tasks in the same cognitive domain) should cluster.\n", + "\n", + "from braindec.predict import preprocess_image\n", + "\n", + "contrast_names = list(hcp_paths.keys())\n", + "image_embeddings = {}\n", + "for domain, img_path in hcp_paths.items():\n", + " img = nib.load(img_path)\n", + " img_emb = preprocess_image(\n", + " img,\n", + " data_dir=data_dir,\n", + " image_emb_gene=image_embedder,\n", + " )\n", + " # Project through CLIP image encoder.\n", + " with torch.no_grad():\n", + " img_feat = model.encode_image(img_emb.to(device))\n", + " img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)\n", + " image_embeddings[domain] = img_feat.cpu().numpy().squeeze()\n", + "\n", + "# Compute pairwise cosine similarity.\n", + "n = len(contrast_names)\n", + "img_sim_matrix = np.zeros((n, n))\n", + "for i, d1 in enumerate(contrast_names):\n", + " for j, d2 in enumerate(contrast_names):\n", + " img_sim_matrix[i, j] = np.dot(image_embeddings[d1], image_embeddings[d2])\n", + "\n", + "fig, ax = plt.subplots(figsize=(6, 5))\n", + "im = ax.imshow(img_sim_matrix, vmin=-1, vmax=1, cmap=\"RdYlBu_r\")\n", + "ax.set_xticks(range(n))\n", + "ax.set_yticks(range(n))\n", + "ax.set_xticklabels(contrast_names, rotation=30, ha=\"right\")\n", + "ax.set_yticklabels(contrast_names)\n", + "plt.colorbar(im, ax=ax, label=\"Cosine similarity\")\n", + "ax.set_title(\"Image–image similarity in CLIP latent space\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0981db9e", + "metadata": {}, + "outputs": [], + "source": [ + "# **6c. Image–text similarity heatmap**\n", + "#\n", + "# Show how strongly each HCP contrast aligns with a curated set of\n", + "# vocabulary terms after projection through the CLIP encoders. High\n", + "# similarity scores (warm colours) indicate that NiCLIP associates a\n", + "# contrast with a given cognitive task.\n", + "\n", + "HIGHLIGHT_TASKS = [\n", + " \"motor fMRI task paradigm\",\n", + " \"language processing fMRI task paradigm\",\n", + " \"emotion processing fMRI task paradigm\",\n", + " \"working memory fMRI task paradigm\",\n", + " \"response inhibition\",\n", + " \"mental rotation\",\n", + " \"face recognition\",\n", + " \"attention\",\n", + "]\n", + "\n", + "# Find indices of the highlight tasks in the vocabulary.\n", + "highlight_idx = []\n", + "highlight_found = []\n", + "for task in HIGHLIGHT_TASKS:\n", + " if task in vocabulary:\n", + " highlight_idx.append(vocabulary.index(task))\n", + " highlight_found.append(task)\n", + " else:\n", + " # Fuzzy match: pick the vocabulary term with the most word overlap.\n", + " query_words = set(task.lower().split())\n", + " best_match = max(\n", + " range(len(vocabulary)),\n", + " key=lambda i: len(query_words & set(vocabulary[i].lower().split())),\n", + " )\n", + " highlight_idx.append(best_match)\n", + " highlight_found.append(vocabulary[best_match])\n", + "\n", + "# Text embeddings for the selected tasks (subset of precomputed array).\n", + "text_emb_subset = torch.from_numpy(vocabulary_emb[highlight_idx]).float().to(device)\n", + "text_emb_subset = text_emb_subset / (text_emb_subset.norm(dim=-1, keepdim=True) + 1e-8)\n", + "\n", + "with torch.no_grad():\n", + " text_feat_subset = model.encode_text(text_emb_subset)\n", + " text_feat_subset = text_feat_subset / text_feat_subset.norm(dim=-1, keepdim=True)\n", + "\n", + "text_feat_np = text_feat_subset.cpu().numpy()\n", + "\n", + "# Build image × text similarity matrix.\n", + "img_text_sim = np.zeros((n, len(highlight_found)))\n", + "for i, domain in enumerate(contrast_names):\n", + " img_text_sim[i] = text_feat_np @ image_embeddings[domain]\n", + "\n", + "fig, ax = plt.subplots(figsize=(10, 4))\n", + "im = ax.imshow(img_text_sim.T, aspect=\"auto\", cmap=\"RdYlBu_r\", vmin=-0.5, vmax=0.5)\n", + "ax.set_xticks(range(n))\n", + "ax.set_yticks(range(len(highlight_found)))\n", + "ax.set_xticklabels(contrast_names)\n", + "ax.set_yticklabels(\n", + " [t[:45] + \"…\" if len(t) > 45 else t for t in highlight_found],\n", + " fontsize=8,\n", + ")\n", + "plt.colorbar(im, ax=ax, label=\"Cosine similarity\")\n", + "ax.set_title(\"Image–text CLIP similarity: HCP contrasts × selected vocabulary terms\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85382905", + "metadata": {}, + "outputs": [], + "source": [ + "# **Summary**\n", + "#\n", + "# This tutorial demonstrated:\n", + "#\n", + "# * **Flat task decoding** (:func:`~braindec.predict.image_to_labels`) —\n", + "# direct task posterior probabilities from a group-level map.\n", + "# * **Hierarchical decoding** (:func:`~braindec.predict.image_to_labels_hierarchical`) —\n", + "# noisy-OR propagation to concept and domain levels.\n", + "# * **ROI characterization** — decoding anatomical binary masks to\n", + "# characterise brain regions without meta-analytic maps.\n", + "# * **Subject-level decoding** — applying the same pipeline to noisier\n", + "# single-subject maps (performance is lower; preprocessing choices matter).\n", + "# * **Custom vocabulary** — subsetting or replacing the Cognitive Atlas\n", + "# vocabulary with domain-specific task lists.\n", + "# * **Latent space exploration** — inspecting the shared image–text\n", + "# embedding space through cosine similarity matrices.\n", + "#\n", + "# Cite NiCLIP as:\n", + "#\n", + "# .. code-block:: text\n", + "#\n", + "# Peraza et al. (2025). NiCLIP: Neuroimaging contrastive language-image\n", + "# pretraining model for predicting text from brain activation images.\n", + "# bioRxiv. https://doi.org/10.1101/2025.06.14.659706" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "include_colab_link": true + }, + "gpuType": "T4", + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..2805610 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,75 @@ +"""Sphinx configuration for braindec documentation.""" + +import os +import sys + +DOCS_DIR = os.path.abspath(os.path.dirname(__file__)) +ROOT = os.path.abspath(os.path.join(DOCS_DIR, "..")) +sys.path.insert(0, ROOT) + +project = "braindec" +copyright = "2025, Braindec developers" +author = "Braindec developers" + +try: + from braindec._version import __version__ + release = __version__ +except ImportError: + release = "unknown" + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx_copybutton", + "sphinx_gallery.gen_gallery", +] + +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +html_theme = "sphinx_rtd_theme" +html_static_path = ["_static"] + +autosummary_generate = True +autodoc_default_options = { + "members": True, + "undoc-members": False, + "show-inheritance": True, +} + +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable", None), + "nibabel": ("https://nipy.org/nibabel", None), + "nilearn": ("https://nilearn.github.io/stable", None), + "nimare": ("https://nimare.readthedocs.io/en/stable", None), + "torch": ("https://docs.pytorch.org/docs/stable", None), +} + +# sphinx-gallery configuration. +# Notebooks for Colab are generated separately via docs/make_notebooks.py +# (which uses jupytext and injects the braindec install cell). sphinx-gallery +# is used only to build the HTML gallery pages from the .py sources. +sphinx_gallery_conf = { + "examples_dirs": ["../examples"], + "gallery_dirs": ["auto_examples"], + # Only process files whose names start with two digits. + "filename_pattern": r"/\d{2}_", + # Set to True locally to execute examples and capture outputs. + # On ReadTheDocs the examples are not executed (too slow / GPU-dependent). + "plot_gallery": os.environ.get("BRAINDEC_BUILD_GALLERY", "0") == "1", + "remove_config_comments": True, + "show_memory": False, + "doc_module": ("braindec",), + "reference_url": {"braindec": None}, + "backreferences_dir": "gen_modules/backreferences", + "image_scrapers": ("matplotlib",), + "default_thumb_file": os.path.join(ROOT, "NiCLIP.png"), + "first_notebook_cell": ( + "%pip install \"braindec[plotting] @ git+https://github.com/jdkent/brain-decoder.git\"\n" + "%matplotlib inline" + ), +} diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..b02b2ab --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,71 @@ +braindec — NiCLIP documentation +================================ + +**braindec** is the Python package for +`NiCLIP `_, a contrastive +language–image pre-training model that decodes brain activation maps into +cognitive task descriptions from the +`Cognitive Atlas `_ ontology. + +.. toctree:: + :maxdepth: 1 + :caption: User guide + + installation + quickstart + auto_examples/index + +.. toctree:: + :maxdepth: 2 + :caption: API reference + + api + +Installation +------------ + +.. code-block:: bash + + pip install "braindec[plotting] @ git+https://github.com/jdkent/brain-decoder.git" + +Quickstart +---------- + +Download the example assets and run functional decoding in a few lines: + +.. code-block:: python + + from braindec.fetcher import download_bundle, get_data_dir + from braindec.predict import image_to_labels + + work_dir = get_data_dir() + download_bundle("example_prediction", destination_root=work_dir) + + # … construct paths, load model, then: + task_df = image_to_labels( + my_activation_map, + model_path=model_fn, + vocabulary=vocabulary, + vocabulary_emb=vocabulary_emb, + prior_probability=vocabulary_prior, + topk=10, + logit_scale=20.0, + ) + print(task_df) + +See the :doc:`examples gallery ` for a full walkthrough. + +Citation +-------- + +.. code-block:: text + + Peraza et al. (2025). NiCLIP: Neuroimaging contrastive language-image + pretraining model for predicting text from brain activation images. + bioRxiv. https://doi.org/10.1101/2025.06.14.659706 + +Indices +------- + +* :ref:`genindex` +* :ref:`modindex` diff --git a/docs/installation.rst b/docs/installation.rst new file mode 100644 index 0000000..2fbc10d --- /dev/null +++ b/docs/installation.rst @@ -0,0 +1,39 @@ +Installation +============ + +Install the package from the GitHub repository: + +.. code-block:: bash + + pip install "braindec @ git+https://github.com/jdkent/brain-decoder.git" + +Install the plotting extras when you want to run the tutorials or make brain +surface figures: + +.. code-block:: bash + + pip install "braindec[plotting] @ git+https://github.com/jdkent/brain-decoder.git" + +Development Install +------------------- + +For local development, clone the repository and install it in editable mode: + +.. code-block:: bash + + git clone https://github.com/jdkent/brain-decoder.git + cd brain-decoder + pip install -e ".[doc,plotting,test]" + +Documentation Build +------------------- + +Build the documentation locally without executing the examples: + +.. code-block:: bash + + python -m sphinx -b html docs docs/_build/html + +The gallery examples are not executed by default because the full NiCLIP +tutorial downloads model assets and may need GPU resources. To execute the +examples locally, set ``BRAINDEC_BUILD_GALLERY=1`` before building. diff --git a/docs/make_notebooks.py b/docs/make_notebooks.py new file mode 100644 index 0000000..485b39a --- /dev/null +++ b/docs/make_notebooks.py @@ -0,0 +1,104 @@ +"""Convert sphinx-gallery .py examples to committed .ipynb files for Colab. + +Usage: + python docs/make_notebooks.py + python docs/make_notebooks.py --examples examples/02_niclip_demo.py + +The generated notebooks are written to docs/auto_examples/ so that the Colab +badge URL in each example resolves to a real file on GitHub. Commit the +output alongside the .py source. + +The sphinx-gallery ``first_notebook_cell`` content (braindec install) is +injected as the first code cell of every notebook. +""" + +import argparse +import json +from pathlib import Path + +INSTALL_CELL_SOURCE = """\ +# Install braindec (this cell is only needed on Google Colab). +%pip install "braindec[plotting] @ git+https://github.com/jdkent/brain-decoder.git" +""" + +MATPLOTLIB_CELL_SOURCE = """\ +# Display Matplotlib figures inline in notebooks. +%matplotlib inline +""" + +INSTALL_CELL = { + "cell_type": "code", + "execution_count": None, + "metadata": {}, + "outputs": [], + "source": [line + "\n" for line in INSTALL_CELL_SOURCE.rstrip().splitlines()], +} + + +def py_to_notebook(py_path: Path, out_path: Path) -> None: + """Convert a sphinx-gallery .py file to an .ipynb, prepending the install cell.""" + from nbformat.v4 import new_code_cell + import jupytext + + notebook = jupytext.read(py_path, fmt="py:percent") + notebook.metadata["accelerator"] = "GPU" + notebook.metadata["gpuType"] = "T4" + notebook.metadata["colab"] = { + "gpuType": "T4", + "include_colab_link": True, + } + notebook.cells.insert(0, new_code_cell(INSTALL_CELL_SOURCE)) + notebook.cells.insert(1, new_code_cell(MATPLOTLIB_CELL_SOURCE)) + jupytext.write(notebook, out_path, fmt="ipynb") + print(f" {py_path.name} → {out_path}") + + +def main(argv=None): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--examples", + nargs="*", + default=None, + help="Specific .py files to convert. Defaults to all examples/NN_*.py files.", + ) + parser.add_argument( + "--stage", + action="store_true", + help="Run 'git add' on generated notebooks (used by the pre-commit hook).", + ) + opts = parser.parse_args(argv) + + repo_root = Path(__file__).parent.parent + out_dir = repo_root / "docs" / "auto_examples" + out_dir.mkdir(parents=True, exist_ok=True) + + if opts.examples: + sources = [Path(p) for p in opts.examples] + else: + sources = sorted((repo_root / "examples").glob("[0-9][0-9]_*.py")) + + if not sources: + print("No example files found.") + return + + try: + import jupytext # noqa: F401 + except ImportError: + raise SystemExit("jupytext is required: pip install jupytext") + + generated = [] + for src in sources: + dest = out_dir / src.with_suffix(".ipynb").name + py_to_notebook(src, dest) + generated.append(dest) + + if opts.stage and generated: + import subprocess + subprocess.run(["git", "add", *generated], check=True) + print("Staged generated notebooks.") + else: + print(f"\nDone. Commit the files in {out_dir} to enable the Colab badge links.") + + +if __name__ == "__main__": + main() diff --git a/docs/quickstart.rst b/docs/quickstart.rst new file mode 100644 index 0000000..85451c9 --- /dev/null +++ b/docs/quickstart.rst @@ -0,0 +1,37 @@ +Quickstart +========== + +The main prediction API decodes a brain activation map into Cognitive Atlas +task labels. The example below shows the shape of a typical workflow. + +.. code-block:: python + + import nibabel as nib + import numpy as np + + from braindec.fetcher import download_bundle, get_data_dir + from braindec.predict import image_to_labels + + work_dir = get_data_dir() + download_bundle("example_prediction", destination_root=work_dir) + + activation_img = nib.load("path/to/activation_map.nii.gz") + vocabulary = ["motor fMRI task paradigm", "language processing fMRI task paradigm"] + vocabulary_emb = np.load("path/to/vocabulary_embeddings.npy") + vocabulary_prior = np.full(len(vocabulary), 1.0 / len(vocabulary)) + + predictions = image_to_labels( + activation_img, + model_path="path/to/model.pth", + vocabulary=vocabulary, + vocabulary_emb=vocabulary_emb, + prior_probability=vocabulary_prior, + topk=10, + logit_scale=20.0, + ) + + print(predictions) + +For an end-to-end workflow with the packaged example assets, HCP contrast maps, +hierarchical decoding, ROI characterization, custom vocabularies, and latent +space plots, see :doc:`auto_examples/02_niclip_demo`. diff --git a/examples/02_niclip_demo.py b/examples/02_niclip_demo.py new file mode 100644 index 0000000..d8f7172 --- /dev/null +++ b/examples/02_niclip_demo.py @@ -0,0 +1,767 @@ +r""" +NiCLIP: Functional Brain Decoding Tutorial +=========================================== + +`NiCLIP `_ is a contrastive +language–image pre-training (CLIP) model trained on ~23,000 neuroimaging +articles that maps brain activation patterns to cognitive task descriptions +from the `Cognitive Atlas `_ ontology. + +This tutorial walks through the main use cases: + +1. **Group-level task decoding** — predict tasks, concepts, and cognitive + process domains from a group-level activation map. +2. **Hierarchical decoding** — obtain predictions at three ontology levels + (tasks → concepts → domains) using the noisy-OR propagation rule. +3. **Brain region characterization** — characterize anatomical ROIs + without pre-computed meta-analytic maps. +4. **Subject-level decoding** — apply decoding to noisier single-subject maps. +5. **Custom vocabulary** — decode against a user-supplied task vocabulary. +6. **Latent space exploration** — visualize the shared image–text embedding + space learned by NiCLIP. + +.. note:: + + **Run this tutorial on Google Colab** + + .. image:: https://colab.research.google.com/assets/colab-badge.svg + :target: https://colab.research.google.com/github/jdkent/brain-decoder/blob/main/docs/auto_examples/02_niclip_demo.ipynb + :alt: Open In Colab + + The first notebook cell installs :code:`braindec` and its dependencies + automatically. The full install takes a few minutes on a fresh Colab + runtime; subsequent runs reuse the cached packages. +""" + +# %% +# Section 0: Download example assets +# ------------------------------------ +# NiCLIP ships a curated set of publishable assets on +# `OSF `_. The ``example_prediction`` bundle +# contains the pre-trained CLIP model, reduced Cognitive Atlas vocabulary, +# pre-computed vocabulary embeddings, vocabulary prior, brain mask, and +# Cognitive Atlas ontology snapshots — everything required to run decoding. +# +# The download is skipped automatically for files that already exist locally. + +import os +import os.path as op +from pathlib import Path + +import nibabel as nib +import numpy as np +import pandas as pd +import requests + +from braindec.fetcher import download_bundle, get_data_dir + +work_dir = get_data_dir() +print(f"Working directory: {work_dir}") + +downloaded = download_bundle("example_prediction", destination_root=work_dir) +print(f"Bundle contains {len(downloaded)} files") + +# %% +# Construct paths to downloaded assets. +# These mirror the OSF folder structure that ``download_bundle`` preserves. + +MODEL_NAME = "BrainGPT-7B-v0.2" +SECTION = "body" +SOURCE = "cogatlasred" +VOC_LABEL = f"vocabulary-{SOURCE}_task-combined_embedding-{MODEL_NAME}" + +data_dir = op.join(work_dir, "data") +results_dir = op.join(work_dir, "results") +voc_dir = op.join(data_dir, "vocabulary") +cog_atlas_dir = op.join(data_dir, "cognitive_atlas") + +model_fn = op.join(results_dir, "pubmed", + f"model-clip_section-{SECTION}_embedding-{MODEL_NAME}_best.pth") +vocabulary_fn = op.join(voc_dir, f"vocabulary-{SOURCE}_task.txt") +vocabulary_emb_fn = op.join(voc_dir, f"{VOC_LABEL}.npy") +vocabulary_prior_fn = op.join(voc_dir, f"{VOC_LABEL}_section-{SECTION}_prior.npy") +mask_fn = op.join(data_dir, "MNI152_2x2x2_brainmask.nii.gz") + +for label, path in [ + ("model", model_fn), + ("vocabulary", vocabulary_fn), + ("vocabulary embeddings", vocabulary_emb_fn), + ("vocabulary prior", vocabulary_prior_fn), + ("brain mask", mask_fn), + ("cognitive atlas", cog_atlas_dir), +]: + status = "✓" if op.exists(path) else "✗ MISSING" + print(f" {status} {label}: {path}") + +# %% +# Download representative HCP group-level contrast maps from +# `NeuroVault `_ (public access, +# no account required). We use three contrasts that span different cognitive +# domains to illustrate decoding across sections 1–6. + +HCP_COLLECTION_ID = 457 +HCP_MAPS = { + "motor": "tfMRI_MOTOR_AVG_zstat1.nii.gz", + "language": "tfMRI_LANGUAGE_STORY-MATH_zstat1.nii.gz", + "emotion": "tfMRI_EMOTION_FACES-SHAPES_zstat1.nii.gz", + "working_memory": "tfMRI_WM_2BK-0BK_zstat1.nii.gz", +} + +hcp_dir = Path(data_dir) / "hcp" / "neurovault" +hcp_dir.mkdir(parents=True, exist_ok=True) + +hcp_paths = {} +for domain, filename in HCP_MAPS.items(): + dest = hcp_dir / filename + if not dest.exists(): + url = f"https://neurovault.org/media/images/{HCP_COLLECTION_ID}/{filename}" + print(f"Downloading {domain} map …") + with requests.get(url, stream=True, timeout=120) as r: + r.raise_for_status() + with open(dest, "wb") as fh: + for chunk in r.iter_content(chunk_size=1024 * 1024): + fh.write(chunk) + hcp_paths[domain] = str(dest) + print(f" {domain}: {dest}") + +# %% +# Initialise shared resources that will be reused across sections. +# Building the model and image embedder once avoids repeated I/O and +# DiFuMo atlas downloads. + +import torch +from braindec.cogatlas import CognitiveAtlas +from braindec.embedding import ImageEmbedding +from braindec.model import build_model +from braindec.utils import _get_device + +device = _get_device() +print(f"Using device: {device}") + +model = build_model(model_fn, device=device) + +image_embedder = ImageEmbedding( + standardize=False, + nilearn_dir=op.join(data_dir, "nilearn"), + space="MNI152", +) + +with open(vocabulary_fn) as fh: + vocabulary = [line.strip() for line in fh] +vocabulary_emb = np.load(vocabulary_emb_fn) +vocabulary_prior = np.load(vocabulary_prior_fn) + +print(f"Vocabulary size: {len(vocabulary)} tasks") +print(f"Embedding shape: {vocabulary_emb.shape}") + +# %% +# Section 1: Group-level task decoding +# ---------------------------------------- +# The primary NiCLIP use case is *functional decoding*: given a brain +# activation map, retrieve the most likely cognitive tasks from the Cognitive +# Atlas. NiCLIP computes posterior probabilities P(T|A) using Bayes' +# theorem over the CLIP cosine similarities. + +import matplotlib.pyplot as plt +from nilearn.plotting import plot_stat_map + +from braindec.predict import image_to_labels + +motor_img = nib.load(hcp_paths["motor"]) + +task_df = image_to_labels( + motor_img, + model_path=model_fn, + vocabulary=vocabulary, + vocabulary_emb=vocabulary_emb, + prior_probability=vocabulary_prior, + topk=10, + logit_scale=20.0, + model=model, + image_emb_gene=image_embedder, + data_dir=data_dir, +) + +print("Top-10 task predictions for HCP Motor (AVG) contrast:") +print(task_df.to_string(index=False)) + +# %% +# Visualise the input activation map and the top-5 task predictions. + +plot_stat_map( + motor_img, + display_mode="z", + cut_coords=5, + colorbar=True, + threshold=2.0, + title="HCP Motor (AVG) z-stat", +) +plt.show() + +top5 = task_df.head(5) +short_labels = [t[:40] + "…" if len(t) > 40 else t for t in top5["pred"]] +fig, ax = plt.subplots(figsize=(8, 4)) +ax.barh(range(len(top5)), top5["prob"], color="steelblue") +ax.set_yticks(range(len(top5))) +ax.set_yticklabels(short_labels, fontsize=9) +ax.invert_yaxis() +ax.set_xlabel("Posterior probability P(T|A)") +ax.set_title("Top-5 task predictions") +ax.set_xlim(0, top5["prob"].max() * 1.2) +plt.tight_layout() +plt.show() + +# %% +# Section 2: Hierarchical decoding +# ------------------------------------- +# NiCLIP propagates task posteriors up the Cognitive Atlas ontology using +# a noisy-OR model to derive concept and cognitive process domain +# probabilities: P(C|A) and P(D|A). This produces interpretations at +# three levels of specificity. + +import json + +from braindec.predict import image_to_labels_hierarchical + +concept_to_process_fn = op.join(cog_atlas_dir, "concept_to_process.json") +with open(concept_to_process_fn) as fh: + concept_to_process = json.load(fh) + +reduced_tasks_df = pd.read_csv(op.join(cog_atlas_dir, "reduced_tasks.csv")) + +cog_atlas = CognitiveAtlas( + data_dir=data_dir, + task_snapshot=op.join(cog_atlas_dir, "task_snapshot-02-19-25.json"), + concept_snapshot=op.join(cog_atlas_dir, "concept_extended_snapshot-02-19-25.json"), + concept_to_process=concept_to_process, + reduced_tasks=reduced_tasks_df, +) + +print(f"Cognitive Atlas: {len(cog_atlas.task_names)} tasks | " + f"{len(cog_atlas.concept_names)} concepts | " + f"{len(cog_atlas.process_names)} domains") + +# %% + +task_df_h, concept_df, domain_df = image_to_labels_hierarchical( + motor_img, + model_path=model_fn, + vocabulary=vocabulary, + vocabulary_emb=vocabulary_emb, + prior_probability=vocabulary_prior, + cognitiveatlas=cog_atlas, + topk=5, + logit_scale=20.0, + model=model, + image_emb_gene=image_embedder, + data_dir=data_dir, +) + +# %% +# Display predictions at all three ontology levels. + +fig, axes = plt.subplots(1, 3, figsize=(16, 4)) +panels = [ + (task_df_h, "Tasks P(T|A)", "prob"), + (concept_df, "Concepts P(C|A)", "prob"), + (domain_df, "Domains P(D|A)", "prob"), +] + +for ax, (df, title, col) in zip(axes, panels): + df_top = df.head(5) + labels = [t[:35] + "…" if len(t) > 35 else t for t in df_top["pred"]] + ax.barh(range(len(df_top)), df_top[col], color="steelblue") + ax.set_yticks(range(len(df_top))) + ax.set_yticklabels(labels, fontsize=8) + ax.invert_yaxis() + ax.set_xlabel("Posterior probability") + ax.set_title(title, fontsize=10) + ax.set_xlim(0, df_top[col].max() * 1.3) + +fig.suptitle("HCP Motor (AVG) — Hierarchical decoding", fontsize=12) +plt.tight_layout() +plt.show() + +# %% +# Section 3: Brain region characterization +# ------------------------------------------- +# Instead of a task activation map, NiCLIP can decode *anatomical ROIs* +# directly — enabling functional characterisation of brain regions without +# requiring pre-computed meta-analytic maps. This supports hypothesis +# generation about the cognitive roles of unstudied regions. +# +# We create binary ROI masks from the Harvard-Oxford atlas available via +# nilearn. The subcortical atlas provides named amygdala, hippocampus, and +# striatal regions. For cortical examples not represented in this atlas, we +# add simple spherical MNI ROIs so the tutorial still covers a broader set of +# functional regions. + +from nilearn import datasets, image as nli_image + +mask_img = nib.load(mask_fn) +mask_data = mask_img.get_fdata() > 0 + + +def make_label_roi(atlas_img, atlas_labels, region_names): + """Create a binary ROI from one or more deterministic atlas label names.""" + atlas_data = atlas_img.get_fdata() + label_to_idx = {name: idx for idx, name in enumerate(atlas_labels)} + roi = np.zeros(atlas_data.shape, dtype=np.float32) + for name in region_names: + if name in label_to_idx: + roi[atlas_data == label_to_idx[name]] = 1.0 + else: + print(f" Warning: '{name}' not found in atlas.") + return nib.Nifti1Image(roi, atlas_img.affine, atlas_img.header) + + +def make_spherical_roi(center_xyz, radius_mm=8): + """Create a spherical ROI in MNI millimeter coordinates.""" + ijk = np.indices(mask_img.shape).reshape(3, -1).T + xyz = nib.affines.apply_affine(mask_img.affine, ijk) + distances = np.linalg.norm(xyz - np.asarray(center_xyz), axis=1) + roi = (distances <= radius_mm).reshape(mask_img.shape) & mask_data + return nib.Nifti1Image(roi.astype(np.float32), mask_img.affine, mask_img.header) + + +ROI_LABEL_SPECS = { + "Amygdala": ["Left Amygdala", "Right Amygdala"], + "Hippocampus": ["Left Hippocampus", "Right Hippocampus"], + "Striatum": ["Left Putamen", "Right Putamen", "Left Caudate", "Right Caudate"], +} + +ROI_COORD_SPECS = { + "Amygdala": (-22, -4, -18), + "Hippocampus": (-26, -20, -14), + "Insula": (-34, 18, 4), + "Striatum": (-18, 8, 4), + "rTPJ": (54, -54, 24), + "vmPFC": (0, 46, -8), +} + +try: + ho_sub = datasets.fetch_atlas_harvard_oxford( + "sub-maxprob-thr25-2mm", + data_dir=op.join(data_dir, "nilearn"), + verbose=1, + ) + roi_images = { + name: make_label_roi(ho_sub.maps, ho_sub.labels, labels) + for name, labels in ROI_LABEL_SPECS.items() + } + roi_images.update({ + name: make_spherical_roi(center) + for name, center in ROI_COORD_SPECS.items() + if name not in roi_images + }) + print("Using Harvard-Oxford atlas ROIs with spherical cortical examples.") +except Exception as exc: + print(f"Harvard-Oxford atlas download failed ({type(exc).__name__}: {exc})") + print("Using fallback spherical MNI ROIs for this tutorial run.") + roi_images = {name: make_spherical_roi(center) for name, center in ROI_COORD_SPECS.items()} + +# %% +# Decode each ROI and collect the top task, concept, and domain. + +roi_summary = [] +for roi_name, roi_img in roi_images.items(): + t_df, c_df, d_df = image_to_labels_hierarchical( + roi_img, + model_path=model_fn, + vocabulary=vocabulary, + vocabulary_emb=vocabulary_emb, + prior_probability=vocabulary_prior, + cognitiveatlas=cog_atlas, + topk=3, + logit_scale=20.0, + model=model, + image_emb_gene=image_embedder, + data_dir=data_dir, + ) + roi_summary.append({ + "ROI": roi_name, + "Top task": t_df.iloc[0]["pred"], + "Task P(T|A)": f"{t_df.iloc[0]['prob']:.3f}", + "Top concept": c_df.iloc[0]["pred"], + "Top domain": d_df.iloc[0]["pred"], + }) + +summary_df = pd.DataFrame(roi_summary) +print(summary_df.to_string(index=False)) + +# %% +# Visualise one ROI alongside its top prediction. + +from nilearn.plotting import plot_roi + +plot_roi( + roi_images["Amygdala"], + title="Amygdala (bilateral)", + display_mode="ortho", + cut_coords=(0, -4, -18), + colorbar=False, +) +plt.show() + +# Decode the amygdala with a finer top-k for the bar chart. +t_df, c_df, d_df = image_to_labels_hierarchical( + roi_images["Amygdala"], + model_path=model_fn, + vocabulary=vocabulary, + vocabulary_emb=vocabulary_emb, + prior_probability=vocabulary_prior, + cognitiveatlas=cog_atlas, + topk=5, + logit_scale=20.0, + model=model, + image_emb_gene=image_embedder, + data_dir=data_dir, +) + +labels = [t[:38] + "…" if len(t) > 38 else t for t in t_df["pred"]] +fig, ax = plt.subplots(figsize=(8, 4)) +ax.barh(range(5), t_df["prob"], color="salmon") +ax.set_yticks(range(5)) +ax.set_yticklabels(labels, fontsize=8) +ax.invert_yaxis() +ax.set_xlabel("P(T|A)") +ax.set_title("Top-5 task predictions for Amygdala ROI") + +plt.tight_layout() +plt.show() + +# %% +# Section 4: Subject-level decoding +# ------------------------------------- +# NiCLIP can decode single-subject activation maps, though performance is +# lower than group-level due to higher noise. Here we simulate a +# subject-level map by adding Gaussian noise to the group-level motor +# contrast, then compare predicted ranks to the clean result. +# +# In practice, you would supply your own subject-level t-stat or z-stat +# NIfTI image in place of the simulated map below. + +motor_data = motor_img.get_fdata() +rng = np.random.default_rng(42) +noise_std = motor_data.std() +noisy_data = motor_data + rng.normal(scale=noise_std, size=motor_data.shape) +noisy_img = nib.Nifti1Image(noisy_data, motor_img.affine, motor_img.header) + +task_df_noisy = image_to_labels( + noisy_img, + model_path=model_fn, + vocabulary=vocabulary, + vocabulary_emb=vocabulary_emb, + prior_probability=vocabulary_prior, + topk=10, + logit_scale=20.0, + model=model, + image_emb_gene=image_embedder, + data_dir=data_dir, +) + +# %% +# Compare predictions from the clean group map vs. the simulated subject map. + +fig, axes = plt.subplots(1, 2, figsize=(14, 5)) +for ax, (df, title, color) in zip( + axes, + [ + (task_df.head(5), "Group-level (clean)", "steelblue"), + (task_df_noisy.head(5), "Subject-level (simulated noise)", "orange"), + ], +): + labels = [t[:40] + "…" if len(t) > 40 else t for t in df["pred"]] + ax.barh(range(len(df)), df["prob"], color=color) + ax.set_yticks(range(len(df))) + ax.set_yticklabels(labels, fontsize=8) + ax.invert_yaxis() + ax.set_xlabel("P(T|A)") + ax.set_title(title) + +fig.suptitle("Motor decoding: group vs. subject-level noise", fontsize=12) +plt.tight_layout() +plt.show() + +# %% +# Section 5: Custom vocabulary decoding +# ----------------------------------------- +# NiCLIP accepts any list of task names paired with their LLM-derived text +# embeddings. This lets you decode against a domain-specific vocabulary +# instead of (or in addition to) the full Cognitive Atlas. +# +# **Two workflows:** +# +# *Workflow A — subset the existing vocabulary.* +# Select a subset of Cognitive Atlas tasks relevant to your study domain +# and decode with just those terms. No additional embedding needed. +# +# *Workflow B — embed entirely new task names.* +# Use :class:`~braindec.embedding.TextEmbedding` with BrainGPT to embed +# custom task descriptions, then pass them directly to +# :func:`~braindec.predict.image_to_labels`. + +# Workflow A: emotion-focused vocabulary subset +EMOTION_KEYWORDS = ["emotion", "fear", "affect", "face", "amygdala", "valence", "threat"] + +custom_idx = [ + i for i, task in enumerate(vocabulary) + if any(kw in task.lower() for kw in EMOTION_KEYWORDS) +] +custom_vocabulary = [vocabulary[i] for i in custom_idx] +custom_vocabulary_emb = vocabulary_emb[custom_idx] +custom_prior = vocabulary_prior[custom_idx] +# Re-normalise prior so probabilities sum to 1. +custom_prior = custom_prior / custom_prior.sum() + +print(f"Custom emotion vocabulary: {len(custom_vocabulary)} tasks") +print(" " + "\n ".join(custom_vocabulary[:8])) + +# %% + +emotion_img = nib.load(hcp_paths["emotion"]) + +task_df_custom = image_to_labels( + emotion_img, + model_path=model_fn, + vocabulary=custom_vocabulary, + vocabulary_emb=custom_vocabulary_emb, + prior_probability=custom_prior, + topk=min(8, len(custom_vocabulary)), + logit_scale=20.0, + model=model, + image_emb_gene=image_embedder, + data_dir=data_dir, +) + +plot_stat_map( + emotion_img, + display_mode="z", + cut_coords=5, + threshold=2.0, + title="HCP Emotion (Faces vs Shapes)", +) +plt.show() + +labels = [t[:38] + "…" if len(t) > 38 else t for t in task_df_custom["pred"]] +fig, ax = plt.subplots(figsize=(8, 4)) +ax.barh(range(len(task_df_custom)), task_df_custom["prob"], color="mediumpurple") +ax.set_yticks(range(len(task_df_custom))) +ax.set_yticklabels(labels, fontsize=8) +ax.invert_yaxis() +ax.set_xlabel("P(T|A)") +ax.set_title("Emotion-focused vocabulary predictions") +plt.tight_layout() +plt.show() + +# %% +# **Workflow B — embedding truly new task names (GPU required).** +# +# If you have task descriptions not present in the Cognitive Atlas, embed +# them with :class:`~braindec.embedding.TextEmbedding` and build a prior +# from uniform weights. The code below is shown for reference; a GPU with +# ≥14 GB VRAM (e.g., A100) is needed to run BrainGPT-7B. + +# .. code-block:: python +# +# from braindec.embedding import TextEmbedding +# +# my_tasks = [ +# "emotional conflict task", +# "social exclusion paradigm", +# "fear extinction training", +# ] +# +# text_embedder = TextEmbedding( +# model_name="BrainGPT/BrainGPT-7B-v0.2", +# batch_size=1, +# ) +# my_vocabulary_emb = text_embedder(my_tasks) # shape (n_tasks, embedding_dim) +# my_prior = np.full(len(my_tasks), 1.0 / len(my_tasks)) +# +# task_df_new = image_to_labels( +# emotion_img, +# model_path=model_fn, +# vocabulary=my_tasks, +# vocabulary_emb=my_vocabulary_emb, +# prior_probability=my_prior, +# topk=len(my_tasks), +# logit_scale=20.0, +# model=model, +# image_emb_gene=image_embedder, +# data_dir=data_dir, +# ) + +# %% +# Section 6: Latent space exploration +# ---------------------------------------- +# NiCLIP learns a shared image–text embedding space. Here we visualise: +# +# * The DiFuMo-512 parcellation atlas used to project activation maps. +# * Cosine similarity between embedded HCP contrasts (image–image). +# * Cosine similarity between embedded HCP contrasts and vocabulary +# terms (image–text). + +# %% +# **6a. DiFuMo-512 parcellation atlas** +# +# NiCLIP compresses each activation map to a 512-dimensional vector using +# the `DiFuMo atlas `_ before +# passing it through the CLIP image encoder. + +from nilearn import datasets as nl_datasets +from nilearn.plotting import plot_roi + +difumo_kwargs = dict(dimension=512, resolution_mm=2, + data_dir=op.join(data_dir, "nilearn")) +try: + difumo = nl_datasets.fetch_atlas_difumo(legacy_format=False, **difumo_kwargs) +except TypeError: + difumo = nl_datasets.fetch_atlas_difumo(**difumo_kwargs) + +# Show a handful of DiFuMo components to illustrate the parcellation. +difumo_img = nib.load(difumo.maps) +n_components_to_show = 6 +for comp_i in range(n_components_to_show): + comp_img = nli_image.index_img(difumo_img, comp_i) + plot_roi( + comp_img, + display_mode="z", + cut_coords=1, + title=f"DiFuMo component {comp_i + 1}", + colorbar=False, + ) +plt.show() + +# %% +# **6b. Image–image cosine similarity across HCP contrasts** +# +# Embed all four downloaded HCP contrasts and measure how similar they are +# to each other in the shared CLIP latent space. Semantically related +# contrasts (e.g., tasks in the same cognitive domain) should cluster. + +from braindec.predict import preprocess_image + +contrast_names = list(hcp_paths.keys()) +image_embeddings = {} +for domain, img_path in hcp_paths.items(): + img = nib.load(img_path) + img_emb = preprocess_image( + img, + data_dir=data_dir, + image_emb_gene=image_embedder, + ) + # Project through CLIP image encoder. + with torch.no_grad(): + img_feat = model.encode_image(img_emb.to(device)) + img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True) + image_embeddings[domain] = img_feat.cpu().numpy().squeeze() + +# Compute pairwise cosine similarity. +n = len(contrast_names) +img_sim_matrix = np.zeros((n, n)) +for i, d1 in enumerate(contrast_names): + for j, d2 in enumerate(contrast_names): + img_sim_matrix[i, j] = np.dot(image_embeddings[d1], image_embeddings[d2]) + +fig, ax = plt.subplots(figsize=(6, 5)) +im = ax.imshow(img_sim_matrix, vmin=-1, vmax=1, cmap="RdYlBu_r") +ax.set_xticks(range(n)) +ax.set_yticks(range(n)) +ax.set_xticklabels(contrast_names, rotation=30, ha="right") +ax.set_yticklabels(contrast_names) +plt.colorbar(im, ax=ax, label="Cosine similarity") +ax.set_title("Image–image similarity in CLIP latent space") +plt.tight_layout() +plt.show() + +# %% +# **6c. Image–text similarity heatmap** +# +# Show how strongly each HCP contrast aligns with a curated set of +# vocabulary terms after projection through the CLIP encoders. High +# similarity scores (warm colours) indicate that NiCLIP associates a +# contrast with a given cognitive task. + +HIGHLIGHT_TASKS = [ + "motor fMRI task paradigm", + "language processing fMRI task paradigm", + "emotion processing fMRI task paradigm", + "working memory fMRI task paradigm", + "response inhibition", + "mental rotation", + "face recognition", + "attention", +] + +# Find indices of the highlight tasks in the vocabulary. +highlight_idx = [] +highlight_found = [] +for task in HIGHLIGHT_TASKS: + if task in vocabulary: + highlight_idx.append(vocabulary.index(task)) + highlight_found.append(task) + else: + # Fuzzy match: pick the vocabulary term with the most word overlap. + query_words = set(task.lower().split()) + best_match = max( + range(len(vocabulary)), + key=lambda i: len(query_words & set(vocabulary[i].lower().split())), + ) + highlight_idx.append(best_match) + highlight_found.append(vocabulary[best_match]) + +# Text embeddings for the selected tasks (subset of precomputed array). +text_emb_subset = torch.from_numpy(vocabulary_emb[highlight_idx]).float().to(device) +text_emb_subset = text_emb_subset / (text_emb_subset.norm(dim=-1, keepdim=True) + 1e-8) + +with torch.no_grad(): + text_feat_subset = model.encode_text(text_emb_subset) + text_feat_subset = text_feat_subset / text_feat_subset.norm(dim=-1, keepdim=True) + +text_feat_np = text_feat_subset.cpu().numpy() + +# Build image × text similarity matrix. +img_text_sim = np.zeros((n, len(highlight_found))) +for i, domain in enumerate(contrast_names): + img_text_sim[i] = text_feat_np @ image_embeddings[domain] + +fig, ax = plt.subplots(figsize=(10, 4)) +im = ax.imshow(img_text_sim.T, aspect="auto", cmap="RdYlBu_r", vmin=-0.5, vmax=0.5) +ax.set_xticks(range(n)) +ax.set_yticks(range(len(highlight_found))) +ax.set_xticklabels(contrast_names) +ax.set_yticklabels( + [t[:45] + "…" if len(t) > 45 else t for t in highlight_found], + fontsize=8, +) +plt.colorbar(im, ax=ax, label="Cosine similarity") +ax.set_title("Image–text CLIP similarity: HCP contrasts × selected vocabulary terms") +plt.tight_layout() +plt.show() + +# %% +# **Summary** +# +# This tutorial demonstrated: +# +# * **Flat task decoding** (:func:`~braindec.predict.image_to_labels`) — +# direct task posterior probabilities from a group-level map. +# * **Hierarchical decoding** (:func:`~braindec.predict.image_to_labels_hierarchical`) — +# noisy-OR propagation to concept and domain levels. +# * **ROI characterization** — decoding anatomical binary masks to +# characterise brain regions without meta-analytic maps. +# * **Subject-level decoding** — applying the same pipeline to noisier +# single-subject maps (performance is lower; preprocessing choices matter). +# * **Custom vocabulary** — subsetting or replacing the Cognitive Atlas +# vocabulary with domain-specific task lists. +# * **Latent space exploration** — inspecting the shared image–text +# embedding space through cosine similarity matrices. +# +# Cite NiCLIP as: +# +# .. code-block:: text +# +# Peraza et al. (2025). NiCLIP: Neuroimaging contrastive language-image +# pretraining model for predicting text from brain activation images. +# bioRxiv. https://doi.org/10.1101/2025.06.14.659706 diff --git a/examples/README.rst b/examples/README.rst new file mode 100644 index 0000000..d2fb4c2 --- /dev/null +++ b/examples/README.rst @@ -0,0 +1,9 @@ +Examples +======== + +These examples demonstrate common ``braindec`` workflows. + +The gallery is built without executing examples on Read the Docs because the +full NiCLIP demo downloads model assets and can require GPU resources. Build +locally with ``BRAINDEC_BUILD_GALLERY=1`` to execute examples and capture +outputs. diff --git a/jobs/__init__.py b/jobs/__init__.py new file mode 100644 index 0000000..ea20182 --- /dev/null +++ b/jobs/__init__.py @@ -0,0 +1 @@ +"""Job helpers and runnable analysis scripts for braindec.""" diff --git a/jobs/build_dataset_mappings.py b/jobs/build_dataset_mappings.py new file mode 100644 index 0000000..25910dd --- /dev/null +++ b/jobs/build_dataset_mappings.py @@ -0,0 +1,123 @@ +"""Build evaluation-ready and representative mapping tables for NeuroVault datasets.""" + +import argparse +import json +import os +import os.path as op +from pathlib import Path + +import pandas as pd + + +IBC_REPRESENTATIVE_RULES = { + "archi_emotional": "nv40641-sub-01-ses-07-expression-intention-gender", + "archi_social": "nv40616-sub-01-ses-00-triangle-mental-random", + "archi_spatial": "nv40602-sub-01-ses-00-saccades", + "archi_standard": "nv40564-sub-01-ses-00-computation-sentences", + "hcp_emotion": "nv40035-sub-01-ses-03-face-shape", + "hcp_gambling": "nv40020-sub-01-ses-03-reward", + "hcp_language": "nv40012-sub-01-ses-03-story-math", + "hcp_motor": "nv40023-sub-01-ses-03-left-hand-avg", + "hcp_relational": "nv40038-sub-01-ses-04-relational-match", + "hcp_social": "nv40016-sub-01-ses-04-mental-random", + "hcp_wm": "nv40041-sub-01-ses-04-2back-0back", + "language_nsp": "nv40654-sub-01-ses-05-complex-simple", +} + +CNP_REPRESENTATIVE_RULES = { + "BART": "nv49974-BART-Accept", + "PAMRET": "nv49999-PAMRET-All", + "SCAP": "nv50048-SCAP-All", + "STOPSIGNAL": "nv50088-STOPSIGNAL-Go-StopSuccess", + "TASKSWITCH": "nv50104-TASKSWITCH-ALL", +} + + +def _load_full_task_set(task_snapshot_fn): + with open(task_snapshot_fn, "r") as file_obj: + tasks = json.load(file_obj) + return {task["name"].strip() for task in tasks if task.get("name")} + + +def _materialize_representative_dir(mapping_df, source_dir, destination_dir): + source_dir = Path(source_dir) + destination_dir = Path(destination_dir) + destination_dir.mkdir(parents=True, exist_ok=True) + for row in mapping_df.itertuples(index=False): + source_path = Path(row.local_path) + link_path = destination_dir / source_path.name + if link_path.exists() or link_path.is_symlink(): + continue + rel_target = os.path.relpath(source_path, start=destination_dir) + link_path.symlink_to(rel_target) + + +def _build_mapping_tables(data_dir, dataset_name, representative_rules): + dataset_dir = Path(data_dir) / dataset_name + mapping_df = pd.read_csv(dataset_dir / "mapping.csv") + mapping_df["task_name"] = mapping_df["task_name"].fillna("").astype(str).str.strip() + + reduced_df = pd.read_csv(Path(data_dir) / "cognitive_atlas" / "reduced_tasks.csv") + reduced_tasks = {task.strip() for task in reduced_df["task"].tolist()} + full_tasks = _load_full_task_set(Path(data_dir) / "cognitive_atlas" / "task_snapshot-02-19-25.json") + + mapping_df["in_reduced_ontology"] = mapping_df["task_name"].isin(reduced_tasks) + mapping_df["in_full_ontology"] = mapping_df["task_name"].isin(full_tasks) + + reduced_mapping_df = mapping_df.loc[mapping_df["in_reduced_ontology"]].copy() + full_mapping_df = mapping_df.loc[mapping_df["in_full_ontology"]].copy() + + representative_rows = [] + for task_family, prediction_label in representative_rules.items(): + matched = mapping_df.loc[mapping_df["prediction_label"] == prediction_label] + if matched.empty: + raise KeyError(f"Could not find representative label {prediction_label!r} for {dataset_name}:{task_family}.") + representative_rows.append(matched.iloc[0].to_dict()) + representative_df = pd.DataFrame(representative_rows).sort_values(by="task_family").reset_index(drop=True) + + reduced_representative_df = representative_df.loc[representative_df["in_reduced_ontology"]].copy() + full_representative_df = representative_df.loc[representative_df["in_full_ontology"]].copy() + + reduced_mapping_df.to_csv(dataset_dir / "mapping_reduced.csv", index=False) + full_mapping_df.to_csv(dataset_dir / "mapping_full.csv", index=False) + representative_df.to_csv(dataset_dir / "mapping_representative.csv", index=False) + reduced_representative_df.to_csv(dataset_dir / "mapping_representative_reduced.csv", index=False) + full_representative_df.to_csv(dataset_dir / "mapping_representative_full.csv", index=False) + + representative_dir = Path(data_dir) / f"{dataset_name}_representative" + _materialize_representative_dir(representative_df, dataset_dir, representative_dir) + + return { + "mapping": mapping_df, + "reduced_mapping": reduced_mapping_df, + "full_mapping": full_mapping_df, + "representative_mapping": representative_df, + "reduced_representative_mapping": reduced_representative_df, + "full_representative_mapping": full_representative_df, + } + + +def _get_parser(): + parser = argparse.ArgumentParser(description="Build evaluation-ready dataset mapping tables.") + parser.add_argument( + "--project_dir", + dest="project_dir", + default=op.abspath(op.join(op.dirname(__file__), "..")), + help="Path to the repository root.", + ) + return parser + + +def main(project_dir): + data_dir = op.join(op.abspath(project_dir), "data") + _build_mapping_tables(data_dir, "ibc", IBC_REPRESENTATIVE_RULES) + _build_mapping_tables(data_dir, "cnp", CNP_REPRESENTATIVE_RULES) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) + + +if __name__ == "__main__": + _main() diff --git a/jobs/decoding_cnp.py b/jobs/decoding_cnp.py new file mode 100644 index 0000000..d261b1d --- /dev/null +++ b/jobs/decoding_cnp.py @@ -0,0 +1,301 @@ +"""Decode CNP maps into task, concept, and domain predictions. + +Expected inputs: +- NIfTI maps in `data/cnp` by default. +- Prediction labels are derived from filenames and written into filenames like + `__section-
_pred-task_brainclip.csv`. + +Use a companion mapping CSV with `prediction_label` plus task metadata when running +`jobs/decoding_eval.py` on the resulting outputs. +""" + +import argparse +import itertools +import os +import os.path as op +from glob import glob + +import pandas as pd +from tqdm.auto import tqdm + +from jobs.utils import ( + DEFAULT_MODEL_IDS, + DEFAULT_SECTIONS, + add_common_job_args, + build_cognitiveatlas, + get_source, + infer_prediction_label, + load_decoding_resources, + resolve_project_paths, +) + + +def _get_parser(): + parser = add_common_job_args(argparse.ArgumentParser(description="Decode CNP maps")) + parser.add_argument( + "--image_dir", + dest="image_dir", + default=None, + help="Optional explicit path to the CNP image directory.", + ) + parser.add_argument( + "--output_dir", + dest="output_dir", + default=None, + help="Optional explicit path to the output directory.", + ) + parser.add_argument( + "--categories", + dest="categories", + nargs="+", + default=["task"], + help="Vocabulary categories to evaluate.", + ) + parser.add_argument( + "--sub_categories", + dest="sub_categories", + nargs="+", + default=["combined"], + help="Vocabulary embedding variants to evaluate.", + ) + parser.add_argument( + "--label_delimiter", + dest="label_delimiter", + default="_", + help="Delimiter used to split image stems when deriving prediction labels.", + ) + parser.add_argument( + "--label_token_index", + dest="label_token_index", + type=int, + default=0, + help="Token index used to derive the prediction label from the image stem.", + ) + parser.add_argument( + "--mapping_fn", + dest="mapping_fn", + default=None, + help="Optional mapping CSV with a local_path column used to select images.", + ) + parser.add_argument( + "--num_shards", + dest="num_shards", + type=int, + default=1, + help="Split the selected image list into this many shards.", + ) + parser.add_argument( + "--shard_index", + dest="shard_index", + type=int, + default=0, + help="0-based shard index to decode from the sharded image list.", + ) + parser.add_argument( + "--skip_existing", + dest="skip_existing", + action="store_true", + help="Skip image/config outputs when all expected CSVs already exist.", + ) + return parser + + +def main( + project_dir=None, + data_dir=None, + results_dir=None, + image_dir=None, + output_dir=None, + sections=None, + categories=None, + sub_categories=None, + model_ids=None, + topk=20, + standardize=False, + logit_scale=20.0, + device=None, + reduced=True, + label_delimiter="_", + label_token_index=0, + mapping_fn=None, + num_shards=1, + shard_index=0, + skip_existing=False, +): + import nibabel as nib + from nilearn.image import resample_to_img + from nimare.annotate.gclda import GCLDAModel + from nimare.decode.continuous import CorrelationDecoder, gclda_decode_map + + from braindec.embedding import ImageEmbedding + from braindec.model import build_model + from braindec.predict import image_to_labels_hierarchical + from braindec.utils import _get_device, images_have_same_fov + + _, data_dir, results_dir = resolve_project_paths(project_dir, data_dir, results_dir) + source = get_source(reduced) + voc_dir = op.join(data_dir, "vocabulary") + image_dir = op.join(data_dir, "cnp") if image_dir is None else op.abspath(image_dir) + output_dir = ( + op.join(results_dir, "predictions_cnp") if output_dir is None else op.abspath(output_dir) + ) + os.makedirs(output_dir, exist_ok=True) + + sections = list(DEFAULT_SECTIONS) if sections is None else sections + categories = ["task"] if categories is None else categories + sub_categories = ["combined"] if sub_categories is None else sub_categories + model_ids = list(DEFAULT_MODEL_IDS) if model_ids is None else model_ids + device = _get_device() if device is None else device + mask_img = nib.load(op.join(data_dir, "MNI152_2x2x2_brainmask.nii.gz")) + + def _resample_for_reference(image, reference_img): + if reference_img is None or images_have_same_fov(image, reference_img): + return image + return resample_to_img(image, reference_img) + + if num_shards < 1: + raise ValueError("--num_shards must be >= 1.") + if shard_index < 0 or shard_index >= num_shards: + raise ValueError("--shard_index must satisfy 0 <= shard_index < num_shards.") + + cognitiveatlas = build_cognitiveatlas(data_dir, reduced) + if mapping_fn is not None: + mapping_df = pd.read_csv(op.abspath(mapping_fn)) + if "local_path" not in mapping_df.columns: + raise KeyError(f"Mapping file {mapping_fn} must include a local_path column.") + images = sorted(mapping_df["local_path"].dropna().astype(str).tolist()) + else: + images = sorted(glob(op.join(image_dir, "*.nii.gz"))) + images = [img_fn for idx, img_fn in enumerate(images) if idx % num_shards == shard_index] + + image_records = [] + for img_fn in tqdm(images, desc="load cnp images", unit="img"): + prediction_label = infer_prediction_label( + img_fn, + delimiter=label_delimiter, + token_index=label_token_index, + ) + img = nib.load(img_fn) + if not images_have_same_fov(img, mask_img): + img = resample_to_img(img, mask_img) + image_records.append((prediction_label, img)) + + for section, category, sub_category, model_id in itertools.product( + sections, + categories, + sub_categories, + model_ids, + ): + resources = load_decoding_resources( + results_dir, + voc_dir, + source, + category, + sub_category, + model_id, + section, + ) + model = build_model(resources["model_path"], device=device) + image_emb_gene = ImageEmbedding( + standardize=standardize, + nilearn_dir=op.join(data_dir, "nilearn"), + space="MNI152", + ) + + gclda_model = None + ns_decoder = None + ns_mask = None + if sub_category == "names": + baseline_label = f"{source}-{category}_embedding-{resources['model_name']}_section-{section}" + ns_model_fn = op.join(results_dir, "baseline", f"model-neurosynth_{baseline_label}.pkl") + gclda_model_fn = op.join(results_dir, "baseline", f"model-gclda_{baseline_label}.pkl") + + gclda_model = GCLDAModel.load(gclda_model_fn) + ns_decoder = CorrelationDecoder.load(ns_model_fn) + ns_masker = None + if hasattr(ns_decoder, "results_"): + ns_masker = getattr(ns_decoder.results_, "masker", None) + if ns_masker is None and hasattr(ns_decoder, "masker"): + ns_masker = ns_decoder.masker + if ns_masker is not None and not hasattr(ns_masker, "clean_args_"): + clean_kwargs = getattr(ns_masker, "clean_kwargs", None) + ns_masker.clean_args_ = {} if clean_kwargs is None else clean_kwargs + if ns_masker is not None: + ns_mask = getattr(ns_masker, "mask_img", None) + if ns_mask is None: + ns_mask = getattr(ns_masker, "mask_img_", None) + + desc = f"decode cnp {section}/{category}/{sub_category}/{resources['model_name']}" + for prediction_label, img in tqdm(image_records, desc=desc, unit="img"): + file_base_name = f"{prediction_label}_{resources['vocabulary_label']}_section-{section}" + task_out_fn = f"{file_base_name}_pred-task_brainclip.csv" + concept_out_fn = f"{file_base_name}_pred-concept_brainclip.csv" + process_out_fn = f"{file_base_name}_pred-process_brainclip.csv" + expected_outputs = [ + op.join(output_dir, task_out_fn), + op.join(output_dir, concept_out_fn), + op.join(output_dir, process_out_fn), + ] + if sub_category == "names": + expected_outputs.extend( + [ + op.join(output_dir, f"{file_base_name}_pred-task_neurosynth.csv"), + op.join(output_dir, f"{file_base_name}_pred-task_gclda.csv"), + ] + ) + if skip_existing and all(op.exists(path) for path in expected_outputs): + continue + + task_prob_df, concept_prob_df, process_prob_df = image_to_labels_hierarchical( + img, + resources["model_path"], + resources["vocabulary"], + resources["vocabulary_emb"], + resources["vocabulary_prior"], + cognitiveatlas, + topk=topk, + standardize=standardize, + logit_scale=logit_scale, + data_dir=data_dir, + device=device, + model=model, + image_emb_gene=image_emb_gene, + ) + task_prob_df.to_csv(op.join(output_dir, task_out_fn), index=False) + concept_prob_df.to_csv(op.join(output_dir, concept_out_fn), index=False) + process_prob_df.to_csv(op.join(output_dir, process_out_fn), index=False) + + if sub_category != "names": + continue + + ns_out_fn = f"{file_base_name}_pred-task_neurosynth.csv" + gclda_out_fn = f"{file_base_name}_pred-task_gclda.csv" + + gclda_img = _resample_for_reference(img, getattr(gclda_model, "mask", None)) + gclda_predictions_df, _ = gclda_decode_map(gclda_model, gclda_img) + gclda_predictions_df = gclda_predictions_df.sort_values(by="Weight", ascending=False).head( + topk + ) + gclda_predictions_df = gclda_predictions_df.reset_index() + gclda_predictions_df.columns = ["pred", "weight"] + gclda_predictions_df.to_csv(op.join(output_dir, gclda_out_fn), index=False) + + ns_img = _resample_for_reference(img, ns_mask) + ns_predictions_df = ns_decoder.transform(ns_img) + feature_group = f"{source}-{category}_section-{section}_annot-tfidf__" + feature_names = ns_predictions_df.index.values + vocabulary_names = [f.replace(feature_group, "") for f in feature_names] + ns_predictions_df.index = vocabulary_names + ns_predictions_df = ns_predictions_df.sort_values(by="r", ascending=False).head(topk) + ns_predictions_df = ns_predictions_df.reset_index() + ns_predictions_df.columns = ["pred", "corr"] + ns_predictions_df.to_csv(op.join(output_dir, ns_out_fn), index=False) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) + + +if __name__ == "__main__": + _main() diff --git a/jobs/decoding_eval.py b/jobs/decoding_eval.py index ad2a401..b1ee9a6 100644 --- a/jobs/decoding_eval.py +++ b/jobs/decoding_eval.py @@ -1,3 +1,6 @@ +"""Evaluate decoding predictions against explicit dataset mappings.""" + +import argparse import itertools import json import os @@ -7,9 +10,16 @@ import numpy as np import pandas as pd -from braindec.cogatlas import CognitiveAtlas +from jobs.utils import ( + DEFAULT_MODEL_IDS, + DEFAULT_SECTIONS, + build_cognitiveatlas, + infer_prediction_label, + parse_name_list, + resolve_project_paths, +) -IMG_TO_DOMAIN = { +LEGACY_HCP_DOMAIN_TO_LABEL = { "EMOTION": "emotion", "GAMBLING": "gambling", "LANGUAGE": "language", @@ -20,167 +30,563 @@ } -def _recall_at_n(true_lb, pred_lb, n): - if isinstance(true_lb, int): - true_lb = [true_lb] - - # Check if empty - # if true_lb: - # print("Empty true labels") - # return 0 - - return len(np.intersect1d(true_lb, pred_lb[:n])) / len(true_lb) - - -def _get_cognitiveatlas(data_dir, reduced): - concept_to_task_fn = op.join(data_dir, "cognitive_atlas", "concept_to_task.json") - with open(concept_to_task_fn, "r") as file: - concept_to_task = json.load(file) - - concept_to_process_fn = op.join(data_dir, "cognitive_atlas", "concept_to_process.json") - with open(concept_to_process_fn, "r") as file: - concept_to_process = json.load(file) - - reduced_tasks_fn = op.join(data_dir, "cognitive_atlas", "reduced_tasks.csv") - reduced_tasks_df = pd.read_csv(reduced_tasks_fn) if reduced else None - - cognitiveatlas = CognitiveAtlas( - data_dir=data_dir, - task_snapshot=op.join(data_dir, "cognitive_atlas", "task_snapshot-02-19-25.json"), - concept_snapshot=op.join( - data_dir, "cognitive_atlas", "concept_extended_snapshot-02-19-25.json" - ), - reduced_tasks=reduced_tasks_df, - # concept_to_task=concept_to_task, - concept_to_process=concept_to_process, - ) - - return cognitiveatlas - - -def main(): - project_dir = "/Users/julioaperaza/Documents/GitHub/brain-decoder" - project_dir = op.abspath(project_dir) - data_dir = op.join(project_dir, "data") - - results_dir = op.join(project_dir, "results") - sections = ["body", "abstract"] - # sections = ["abstract", "body"] - sub_categories = ["combined", "names"] - # sub_categories = ["names", "definitions", "combined"] - categories = ["task"] # ["task", "concept"] - model_ids = [ - "BrainGPT/BrainGPT-7B-v0.2", - "mistralai/Mistral-7B-v0.1", - "BrainGPT/BrainGPT-7B-v0.1", - "meta-llama/Llama-2-7b-chat-hf", - ] - sources = ["cogatlasred", "cogatlas"] # "cogatlas" - models = [ - "neurosynth", - "gclda", - "brainclip", - ] - - domains = [ - "emotion", - "gambling", - "language", - "motor", - "relational", - "social", - "working memory", - ] - subdomains = ["task", "concept", "domain"] - - output_dir = op.join(results_dir, "predictions_hcp_nv") - os.makedirs(output_dir, exist_ok=True) - - image_dir = op.join(data_dir, "hcp", "neurovault") - images = sorted(glob(op.join(image_dir, "*.nii.gz"))) +def _recall_at_n(true_labels, pred_labels, n): + if not true_labels: + return np.nan + return len(set(true_labels) & set(pred_labels[:n])) / len(true_labels) + + +def _best_rank(true_labels, pred_labels): + true_labels = set(true_labels) + for rank, pred in enumerate(pred_labels, start=1): + if pred in true_labels: + return rank + return np.inf + + +def _resolve_column(df, requested_name, candidates, required=False): + names = [requested_name] if requested_name is not None else [] + names.extend(candidates) + for name in names: + if name and name in df.columns: + return name + if required: + raise KeyError(f"Could not find any of the required columns: {names}.") + return None + - ground_truth_fn = op.join(data_dir, "ibc", "ground_truth.json") - with open(ground_truth_fn, "r") as file: - ground_truth = json.load(file) - - results_dict = { - # "domain": [], - "model": [], - "task_gclda": [], - "task_neurosynth": [], - "task_brainclip": [], - "concept": [], - "process": [], +def _build_record_from_task(prediction_label, task_names, cognitiveatlas, dataset_name, row_id): + task_names = parse_name_list(task_names) + if not task_names: + raise ValueError(f"Missing task labels for record {row_id}.") + + task_idx = cognitiveatlas.get_task_idx_from_names(task_names) + task_idxs = [task_idx] if isinstance(task_idx, (int, np.integer)) else list(task_idx) + concept_idx = np.unique( + np.concatenate([cognitiveatlas.task_to_concept_idxs[idx] for idx in task_idxs]) + if task_idxs + else np.array([], dtype=int) + ) + concept_idx = concept_idx.astype(int, copy=False) + domain_idx = np.unique( + np.concatenate([cognitiveatlas.task_to_process_idxs[idx] for idx in task_idxs]) + if task_idxs + else np.array([], dtype=int) + ) + domain_idx = domain_idx.astype(int, copy=False) + concept_names = cognitiveatlas.get_concept_names_from_idx(concept_idx).tolist() + domain_names = cognitiveatlas.get_process_names_from_idx(domain_idx).tolist() + return { + "dataset": dataset_name, + "prediction_label": prediction_label, + "task": task_names, + "concept": concept_names, + "domain": domain_names, } - for section, model_id, source, category, sub_category in itertools.product( - sections, model_ids, sources, categories, sub_categories - ): - model_name = model_id.split("/")[-1] - reduced = True if source == "cogatlasred" else False - cognitiveatlas = _get_cognitiveatlas(data_dir, reduced) - vocabulary_lb = f"vocabulary-{source}_{category}-{sub_category}_embedding-{model_name}_section-{section}" - results_dict["model"].append(vocabulary_lb) - for model in models: - # results_dict["model"].append(f"{vocabulary_lb}_{model}") +def _load_mapping_records( + dataset_name, + mapping_fn, + mapping_label_column, + mapping_task_column, + mapping_concepts_column, + mapping_domains_column, + mapping_filename_column, + mapping_label_delimiter, + mapping_label_token_index, + cognitiveatlas, +): + mapping_df = pd.read_csv(mapping_fn) + label_column = _resolve_column( + mapping_df, + mapping_label_column, + ["prediction_label", "task_code", "image_label", "label"], + required=False, + ) + filename_column = _resolve_column( + mapping_df, + mapping_filename_column, + ["filename", "image_path", "image", "path"], + required=label_column is None, + ) + task_column = _resolve_column( + mapping_df, + mapping_task_column, + ["task_name", "task"], + required=True, + ) + concepts_column = _resolve_column( + mapping_df, + mapping_concepts_column, + ["concepts", "concept", "true_concepts"], + required=False, + ) + domains_column = _resolve_column( + mapping_df, + mapping_domains_column, + ["domains", "domain", "true_domains"], + required=False, + ) + + records = [] + for row_idx, row in mapping_df.iterrows(): + if label_column is not None: + prediction_label = str(row[label_column]).strip() + else: + prediction_label = infer_prediction_label( + row[filename_column], + delimiter=mapping_label_delimiter, + token_index=mapping_label_token_index, + ) + + task_names = parse_name_list(row[task_column]) + if not task_names: + raise ValueError(f"Row {row_idx} in {mapping_fn} does not define any task labels.") + + if concepts_column is None or domains_column is None: + if cognitiveatlas is None: + raise ValueError( + "A CognitiveAtlas object is required when the mapping file does not define " + "explicit concepts/domains columns." + ) + record = _build_record_from_task( + prediction_label, + task_names, + cognitiveatlas, + dataset_name=dataset_name, + row_id=f"{mapping_fn}:{row_idx}", + ) + else: + record = { + "dataset": dataset_name, + "prediction_label": prediction_label, + "task": task_names, + "concept": parse_name_list(row[concepts_column]), + "domain": parse_name_list(row[domains_column]), + } + + records.append(record) + + return records + + +def _load_legacy_hcp_records(image_dir, ground_truth_fn): + with open(ground_truth_fn, "r") as file_obj: + ground_truth = json.load(file_obj) - if model != "brainclip" and sub_category != "names": - results_dict[f"task_{model}"].append(np.nan) + images = sorted(glob(op.join(image_dir, "*.nii.gz"))) + records = [] + for img_fn in images: + image_name = op.basename(img_fn).split(".")[0] + task_code = image_name.split("_")[1] + domain_label = LEGACY_HCP_DOMAIN_TO_LABEL[task_code] + gt = ground_truth.get(domain_label) or ground_truth.get(domain_label.replace(" ", "_")) + if gt is None: + raise KeyError(f"Could not find ground truth for HCP domain {domain_label!r}.") + records.append( + { + "dataset": "hcp", + "prediction_label": task_code, + "task": parse_name_list(gt["task"]), + "concept": parse_name_list(gt["concept"]), + "domain": parse_name_list(gt["domain"]), + } + ) + return records + + +def _get_parser(): + parser = argparse.ArgumentParser(description="Evaluate decoding predictions against ground truth.") + parser.add_argument( + "--project_dir", + dest="project_dir", + default=None, + help="Path to the repository root.", + ) + parser.add_argument( + "--data_dir", + dest="data_dir", + default=None, + help="Optional explicit data directory.", + ) + parser.add_argument( + "--results_dir", + dest="results_dir", + default=None, + help="Optional explicit results directory.", + ) + parser.add_argument( + "--dataset_name", + dest="dataset_name", + default="hcp", + help="Dataset label to attach to evaluation rows.", + ) + parser.add_argument( + "--sections", + dest="sections", + nargs="+", + default=list(DEFAULT_SECTIONS), + help="Text sections to evaluate.", + ) + parser.add_argument( + "--model_ids", + dest="model_ids", + nargs="+", + default=list(DEFAULT_MODEL_IDS), + help="One or more embedding model identifiers to evaluate.", + ) + parser.add_argument( + "--sources", + dest="sources", + nargs="+", + default=["cogatlasred", "cogatlas"], + help="Vocabulary sources to evaluate.", + ) + parser.add_argument( + "--categories", + dest="categories", + nargs="+", + default=["task"], + help="Vocabulary categories to evaluate.", + ) + parser.add_argument( + "--sub_categories", + dest="sub_categories", + nargs="+", + default=["combined", "names"], + help="Vocabulary embedding variants to evaluate.", + ) + parser.add_argument( + "--models", + dest="models", + nargs="+", + default=["neurosynth", "gclda", "brainclip"], + help="Prediction backends expected in the output directory.", + ) + parser.add_argument( + "--prediction_dir", + dest="prediction_dir", + default=None, + help="Optional explicit prediction directory.", + ) + parser.add_argument( + "--mapping_fn", + dest="mapping_fn", + default=None, + help="CSV describing evaluation items. Preferred for IBC/CNP and explicit HCP mappings.", + ) + parser.add_argument( + "--mapping_label_column", + dest="mapping_label_column", + default=None, + help="Optional mapping CSV column containing the prediction label prefix.", + ) + parser.add_argument( + "--mapping_task_column", + dest="mapping_task_column", + default=None, + help="Optional mapping CSV column containing task labels.", + ) + parser.add_argument( + "--mapping_concepts_column", + dest="mapping_concepts_column", + default=None, + help="Optional mapping CSV column containing concept labels.", + ) + parser.add_argument( + "--mapping_domains_column", + dest="mapping_domains_column", + default=None, + help="Optional mapping CSV column containing domain labels.", + ) + parser.add_argument( + "--mapping_filename_column", + dest="mapping_filename_column", + default=None, + help="Optional mapping CSV column used to derive the prediction label when no label column exists.", + ) + parser.add_argument( + "--mapping_label_delimiter", + dest="mapping_label_delimiter", + default="_", + help="Delimiter used when deriving prediction labels from filenames.", + ) + parser.add_argument( + "--mapping_label_token_index", + dest="mapping_label_token_index", + type=int, + default=None, + help="Optional token index used when deriving prediction labels from filenames.", + ) + parser.add_argument( + "--image_dir", + dest="image_dir", + default=None, + help="Legacy HCP fallback: path to evaluation images used to infer labels from filenames.", + ) + parser.add_argument( + "--ground_truth_fn", + dest="ground_truth_fn", + default=None, + help="Legacy HCP fallback: path to the HCP ground-truth JSON mapping file.", + ) + parser.add_argument( + "--task_k", + dest="task_k", + type=int, + default=4, + help="Recall@K cutoff for task predictions.", + ) + parser.add_argument( + "--concept_k", + dest="concept_k", + type=int, + default=4, + help="Recall@K cutoff for concept predictions.", + ) + parser.add_argument( + "--domain_k", + dest="domain_k", + type=int, + default=2, + help="Recall@K cutoff for domain predictions.", + ) + parser.add_argument( + "--output_fn", + dest="output_fn", + default=None, + help="Optional explicit aggregate output CSV path.", + ) + parser.add_argument( + "--details_output_fn", + dest="details_output_fn", + default=None, + help="Optional explicit detailed output CSV path.", + ) + return parser + + +def main( + project_dir=None, + data_dir=None, + results_dir=None, + dataset_name="hcp", + sections=None, + model_ids=None, + sources=None, + categories=None, + sub_categories=None, + models=None, + prediction_dir=None, + mapping_fn=None, + mapping_label_column=None, + mapping_task_column=None, + mapping_concepts_column=None, + mapping_domains_column=None, + mapping_filename_column=None, + mapping_label_delimiter="_", + mapping_label_token_index=None, + image_dir=None, + ground_truth_fn=None, + task_k=4, + concept_k=4, + domain_k=2, + output_fn=None, + details_output_fn=None, +): + _, data_dir, results_dir = resolve_project_paths(project_dir, data_dir, results_dir) + sections = list(DEFAULT_SECTIONS) if sections is None else sections + model_ids = list(DEFAULT_MODEL_IDS) if model_ids is None else model_ids + sources = ["cogatlasred", "cogatlas"] if sources is None else sources + categories = ["task"] if categories is None else categories + sub_categories = ["combined", "names"] if sub_categories is None else sub_categories + models = ["neurosynth", "gclda", "brainclip"] if models is None else models + + prediction_dir = ( + op.join(results_dir, f"predictions_{dataset_name}") if prediction_dir is None else op.abspath(prediction_dir) + ) + image_dir = op.join(data_dir, "hcp", "neurovault") if image_dir is None else op.abspath(image_dir) + ground_truth_fn = ( + op.join(data_dir, "hcp", "ground_truth.json") + if ground_truth_fn is None + else op.abspath(ground_truth_fn) + ) + output_fn = ( + op.join(results_dir, f"eval-{dataset_name}_results.csv") if output_fn is None else op.abspath(output_fn) + ) + details_output_fn = ( + op.join(results_dir, f"eval-{dataset_name}_details.csv") + if details_output_fn is None + else op.abspath(details_output_fn) + ) + os.makedirs(op.dirname(output_fn), exist_ok=True) + os.makedirs(op.dirname(details_output_fn), exist_ok=True) + + detailed_rows = [] + aggregate_rows = [] + + for section, model_id, source, category, sub_category in itertools.product( + sections, + model_ids, + sources, + categories, + sub_categories, + ): + reduced = source == "cogatlasred" + cognitiveatlas = None + if mapping_fn is not None: + mapping_preview_df = pd.read_csv(op.abspath(mapping_fn), nrows=1) + preview_concepts_column = _resolve_column( + mapping_preview_df, + mapping_concepts_column, + ["concepts", "concept", "true_concepts"], + required=False, + ) + preview_domains_column = _resolve_column( + mapping_preview_df, + mapping_domains_column, + ["domains", "domain", "true_domains"], + required=False, + ) + needs_cognitiveatlas = preview_concepts_column is None or preview_domains_column is None + else: + needs_cognitiveatlas = True + + if needs_cognitiveatlas: + concept_to_process_fn = op.join(data_dir, "cognitive_atlas", "concept_to_process.json") + cognitiveatlas = build_cognitiveatlas(data_dir, reduced, concept_to_process_fn) + model_name = model_id.split("/")[-1] + vocabulary_label = ( + f"vocabulary-{source}_{category}-{sub_category}_embedding-{model_name}_section-{section}" + ) + + if mapping_fn is not None: + records = _load_mapping_records( + dataset_name=dataset_name, + mapping_fn=op.abspath(mapping_fn), + mapping_label_column=mapping_label_column, + mapping_task_column=mapping_task_column, + mapping_concepts_column=mapping_concepts_column, + mapping_domains_column=mapping_domains_column, + mapping_filename_column=mapping_filename_column, + mapping_label_delimiter=mapping_label_delimiter, + mapping_label_token_index=mapping_label_token_index, + cognitiveatlas=cognitiveatlas, + ) + else: + if not op.exists(ground_truth_fn): + raise FileNotFoundError( + f"Ground-truth file not found at {ground_truth_fn}. Pass --mapping_fn or --ground_truth_fn." + ) + records = _load_legacy_hcp_records(image_dir=image_dir, ground_truth_fn=ground_truth_fn) + + for backend in models: + if backend != "brainclip" and sub_category != "names": continue - temp_results = {dom: {subdom: [] for subdom in subdomains} for dom in domains} - for _, img_fn in enumerate(images): - image_name = op.basename(img_fn).split(".")[0] - task_name = image_name.split("_")[1] - file_lb = f"{task_name}_{vocabulary_lb}" - - domain = IMG_TO_DOMAIN[task_name] - task_true = ground_truth[domain]["task"] - - task_true_idx = cognitiveatlas.get_task_idx_from_names(task_true) - - task_out_fn = f"{file_lb}_pred-task_{model}.csv" - task_prob_df = pd.read_csv(op.join(output_dir, task_out_fn)) - task_pred = task_prob_df["pred"].values - task_pred = task_pred[:5] - task_pred_idx = cognitiveatlas.get_task_idx_from_names(task_pred) - task_recall = _recall_at_n(task_true_idx, task_pred_idx, 4) - temp_results[domain]["task"].append(task_recall) - - if model == "brainclip": - concept_out_fn = f"{file_lb}_pred-concept_{model}.csv" - process_out_fn = f"{file_lb}_pred-process_{model}.csv" - - concept_true = ground_truth[domain]["concept"] - process_true = ground_truth[domain]["domain"] - concept_true_idx = cognitiveatlas.get_concept_idx_from_names(concept_true) - process_true_idx = cognitiveatlas.get_process_idx_from_names(process_true) - - concept_prob_df = pd.read_csv(op.join(output_dir, concept_out_fn)) - process_prob_df = pd.read_csv(op.join(output_dir, process_out_fn)) - concept_pred = concept_prob_df["pred"].values - concept_pred_idx = cognitiveatlas.get_concept_idx_from_names(concept_pred) - process_pred = process_prob_df["pred"].values - process_pred_idx = cognitiveatlas.get_process_idx_from_names(process_pred) - concept_recall = _recall_at_n(concept_true_idx, concept_pred_idx, 4) - process_recalls = _recall_at_n(process_true_idx, process_pred_idx, 2) - temp_results[domain]["concept"].append(concept_recall) - temp_results[domain]["domain"].append(process_recalls) - - mean_task_recalls = np.mean([temp_results[dom]["task"] for dom in domains]) - results_dict[f"task_{model}"].append(mean_task_recalls) - - if model == "brainclip": - mean_concept_recalls = np.mean([temp_results[dom]["concept"] for dom in domains]) - mean_process_recalls = np.mean([temp_results[dom]["domain"] for dom in domains]) - results_dict["concept"].append(mean_concept_recalls) - results_dict["process"].append(mean_process_recalls) - - # Export results to csv - results_df = pd.DataFrame(results_dict) - results_df.to_csv(op.join(results_dir, "eval-hcp-group_results.csv"), index=False) + backend_level_rows = [] + for record in records: + file_base = f"{record['prediction_label']}_{vocabulary_label}" + + task_path = op.join(prediction_dir, f"{file_base}_pred-task_{backend}.csv") + if not op.exists(task_path): + raise FileNotFoundError(f"Prediction file not found: {task_path}") + task_pred_df = pd.read_csv(task_path) + task_preds = task_pred_df["pred"].tolist() + task_rank = _best_rank(record["task"], task_preds) + task_recall = _recall_at_n(record["task"], task_preds, task_k) + task_row = { + "dataset": record["dataset"], + "prediction_label": record["prediction_label"], + "source": source, + "category": category, + "sub_category": sub_category, + "section": section, + "model_id": model_id, + "model_name": model_name, + "vocabulary_label": vocabulary_label, + "backend": backend, + "level": "task", + "k": task_k, + "num_true_labels": len(record["task"]), + "recall_at_k": task_recall, + "hit_at_k": float(task_rank <= task_k), + "best_rank": task_rank, + "true_labels_json": json.dumps(record["task"]), + "top_predictions_json": json.dumps(task_preds), + } + detailed_rows.append(task_row) + backend_level_rows.append(task_row) + + if backend != "brainclip": + continue + + level_specs = [ + ("concept", concept_k, record["concept"], op.join(prediction_dir, f"{file_base}_pred-concept_{backend}.csv")), + ("domain", domain_k, record["domain"], op.join(prediction_dir, f"{file_base}_pred-process_{backend}.csv")), + ] + for level_name, k_value, true_labels, pred_path in level_specs: + if not op.exists(pred_path): + raise FileNotFoundError(f"Prediction file not found: {pred_path}") + pred_df = pd.read_csv(pred_path) + pred_labels = pred_df["pred"].tolist() + best_rank = _best_rank(true_labels, pred_labels) + recall = _recall_at_n(true_labels, pred_labels, k_value) + level_row = { + "dataset": record["dataset"], + "prediction_label": record["prediction_label"], + "source": source, + "category": category, + "sub_category": sub_category, + "section": section, + "model_id": model_id, + "model_name": model_name, + "vocabulary_label": vocabulary_label, + "backend": backend, + "level": level_name, + "k": k_value, + "num_true_labels": len(true_labels), + "recall_at_k": recall, + "hit_at_k": float(best_rank <= k_value), + "best_rank": best_rank, + "true_labels_json": json.dumps(true_labels), + "top_predictions_json": json.dumps(pred_labels), + } + detailed_rows.append(level_row) + backend_level_rows.append(level_row) + + backend_df = pd.DataFrame(backend_level_rows) + for level_name, level_df in backend_df.groupby("level", sort=False): + finite_ranks = level_df["best_rank"].replace(np.inf, np.nan) + aggregate_rows.append( + { + "dataset": dataset_name, + "source": source, + "category": category, + "sub_category": sub_category, + "section": section, + "model_id": model_id, + "model_name": model_name, + "vocabulary_label": vocabulary_label, + "backend": backend, + "level": level_name, + "k": int(level_df["k"].iloc[0]), + "n_images": int(len(level_df)), + "mean_recall_at_k": float(level_df["recall_at_k"].mean()), + "mean_hit_at_k": float(level_df["hit_at_k"].mean()), + "median_best_rank": float(np.nanmedian(finite_ranks)), + } + ) + + pd.DataFrame(aggregate_rows).to_csv(output_fn, index=False) + pd.DataFrame(detailed_rows).to_csv(details_output_fn, index=False) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) if __name__ == "__main__": - main() + _main() diff --git a/jobs/decoding_hcp_nv.py b/jobs/decoding_hcp_nv.py index 31c9a5b..2e4eb11 100644 --- a/jobs/decoding_hcp_nv.py +++ b/jobs/decoding_hcp_nv.py @@ -1,68 +1,99 @@ +import argparse import itertools import os import os.path as op from glob import glob -import nibabel as nib -import pandas as pd -from nilearn._utils.niimg_conversions import check_same_fov -from nilearn.image import resample_to_img -from nimare.annotate.gclda import GCLDAModel -from nimare.dataset import Dataset -from nimare.decode.continuous import CorrelationDecoder, gclda_decode_map -from utils import _read_vocabulary - -from braindec.cogatlas import CognitiveAtlas -from braindec.plot import plot_surf -from braindec.predict import image_to_labels_hierarchical - - -def main(): - project_dir = "/Users/julioaperaza/Documents/GitHub/brain-decoder" - project_dir = op.abspath(project_dir) - data_dir = op.join(project_dir, "data") - reduced = True +from jobs.utils import ( + DEFAULT_MODEL_IDS, + DEFAULT_SECTIONS, + add_common_job_args, + build_cognitiveatlas, + get_source, + load_decoding_resources, + resolve_project_paths, +) + + +def _get_parser(): + parser = add_common_job_args(argparse.ArgumentParser(description="Decode HCP NeuroVault maps")) + parser.add_argument( + "--image_dir", + dest="image_dir", + default=None, + help="Optional explicit path to the HCP NeuroVault image directory.", + ) + parser.add_argument( + "--output_dir", + dest="output_dir", + default=None, + help="Optional explicit path to the output directory.", + ) + parser.add_argument( + "--categories", + dest="categories", + nargs="+", + default=["task"], + help="Vocabulary categories to evaluate.", + ) + parser.add_argument( + "--sub_categories", + dest="sub_categories", + nargs="+", + default=["combined"], + help="Vocabulary embedding variants to evaluate.", + ) + return parser + + +def main( + project_dir=None, + data_dir=None, + results_dir=None, + image_dir=None, + output_dir=None, + sections=None, + categories=None, + sub_categories=None, + model_ids=None, + topk=20, + standardize=False, + logit_scale=20.0, + device=None, + reduced=True, +): + import nibabel as nib + from nilearn.image import resample_to_img + from nimare.annotate.gclda import GCLDAModel + from nimare.dataset import Dataset + from nimare.decode.continuous import CorrelationDecoder, gclda_decode_map + + from braindec.plot import plot_surf + from braindec.predict import image_to_labels_hierarchical + from braindec.utils import _get_device, images_have_same_fov + + _, data_dir, results_dir = resolve_project_paths(project_dir, data_dir, results_dir) + source = get_source(reduced) voc_dir = op.join(data_dir, "vocabulary") - source = "cogatlasred" if reduced else "cogatlas" - results_dir = op.join(project_dir, "results") - sections = ["abstract", "body"] - sub_categories = ["combined"] - categories = ["task"] - model_ids = [ - "BrainGPT/BrainGPT-7B-v0.2", - "mistralai/Mistral-7B-v0.1", - "BrainGPT/BrainGPT-7B-v0.1", - "meta-llama/Llama-2-7b-chat-hf", - ] - topk = 20 - standardize = False - logit_scale = 20 - device = "mps" - - output_dir = op.join(results_dir, "predictions_hcp_nv") + image_dir = op.join(data_dir, "hcp", "neurovault") if image_dir is None else op.abspath(image_dir) + output_dir = ( + op.join(results_dir, "predictions_hcp_nv") if output_dir is None else op.abspath(output_dir) + ) os.makedirs(output_dir, exist_ok=True) - dset = Dataset.load(op.join(data_dir, f"dset-pubmed_{source}-annotated_nimare.pkl")) - - reduced_tasks_fn = op.join(data_dir, "cognitive_atlas", "reduced_tasks.csv") - reduced_tasks_df = pd.read_csv(reduced_tasks_fn) if reduced else None - - cognitiveatlas = CognitiveAtlas( - data_dir=data_dir, - task_snapshot=op.join(data_dir, "cognitive_atlas", "task_snapshot-02-19-25.json"), - concept_snapshot=op.join( - data_dir, "cognitive_atlas", "concept_extended_snapshot-02-19-25.json" - ), - reduced_tasks=reduced_tasks_df, - ) + sections = list(DEFAULT_SECTIONS) if sections is None else sections + categories = ["task"] if categories is None else categories + sub_categories = ["combined"] if sub_categories is None else sub_categories + model_ids = list(DEFAULT_MODEL_IDS) if model_ids is None else model_ids + device = _get_device() if device is None else device - image_dir = op.join(data_dir, "hcp", "neurovault") + dset = Dataset.load(op.join(data_dir, f"dset-pubmed_{source}-annotated_nimare.pkl")) + cognitiveatlas = build_cognitiveatlas(data_dir, reduced) images = sorted(glob(op.join(image_dir, "*.nii.gz"))) - for _, img_fn in enumerate(images): + for img_fn in images: image_name = op.basename(img_fn).split(".")[0] task_name = image_name.split("_")[1] - # Plot map for debugging plot_surf( img_fn, op.join(output_dir, f"{task_name}_map.png"), @@ -71,39 +102,36 @@ def main(): ) img = nib.load(img_fn) - if not check_same_fov(img, reference_masker=dset.masker.mask_img): + if not images_have_same_fov(img, dset.masker.mask_img): img = resample_to_img(img, dset.masker.mask_img) for section, category, sub_category, model_id in itertools.product( - sections, categories, sub_categories, model_ids + sections, + categories, + sub_categories, + model_ids, ): - model_name = model_id.split("/")[-1] - model_path = op.join( + resources = load_decoding_resources( results_dir, - "pubmed", - f"model-clip_section-{section}_embedding-{model_name}_best.pth", - ) - vocabulary_lb = f"vocabulary-{source}_{category}-{sub_category}_embedding-{model_name}" - vocabulary_fn = op.join(voc_dir, f"vocabulary-{source}_{category}.txt") - vocabulary_emb_fn = op.join(voc_dir, f"{vocabulary_lb}.npy") - vocabulary_prior_fn = op.join(voc_dir, f"{vocabulary_lb}_section-{section}_prior.npy") - vocabulary, vocabulary_emb, vocabulary_prior = _read_vocabulary( - vocabulary_fn, - vocabulary_emb_fn, - vocabulary_prior_fn, + voc_dir, + source, + category, + sub_category, + model_id, + section, ) - file_base_name = f"{task_name}_{vocabulary_lb}_section-{section}" + file_base_name = f"{task_name}_{resources['vocabulary_label']}_section-{section}" task_out_fn = f"{file_base_name}_pred-task_brainclip.csv" concept_out_fn = f"{file_base_name}_pred-concept_brainclip.csv" process_out_fn = f"{file_base_name}_pred-process_brainclip.csv" task_prob_df, concept_prob_df, process_prob_df = image_to_labels_hierarchical( img, - model_path, - vocabulary, - vocabulary_emb, - vocabulary_prior, + resources["model_path"], + resources["vocabulary"], + resources["vocabulary_emb"], + resources["vocabulary_prior"], cognitiveatlas, topk=topk, standardize=standardize, @@ -115,48 +143,41 @@ def main(): concept_prob_df.to_csv(op.join(output_dir, concept_out_fn), index=False) process_prob_df.to_csv(op.join(output_dir, process_out_fn), index=False) - if sub_category == "names": - baseline_label = f"{source}-{category}_embedding-{model_name}_section-{section}" - # Baseline - # -------------------------------------------------------------------- - ns_out_fn = f"{file_base_name}_pred-task_neurosynth.csv" - gclda_out_fn = f"{file_base_name}_pred-task_gclda.csv" - - # Load baseline model - ns_model_fn = op.join( - results_dir, - "baseline", - f"model-neurosynth_{baseline_label}.pkl", - ) - gclda_model_fn = op.join( - results_dir, - "baseline", - f"model-gclda_{baseline_label}.pkl", - ) - - gclda_model = GCLDAModel.load(gclda_model_fn) - gclda_predictions_df, _ = gclda_decode_map(gclda_model, img) - gclda_predictions_df = gclda_predictions_df.sort_values( - by="Weight", ascending=False - ).head(topk) - gclda_predictions_df = gclda_predictions_df.reset_index() - gclda_predictions_df.columns = ["pred", "weight"] - gclda_predictions_df.to_csv(op.join(output_dir, gclda_out_fn), index=False) - - ns_decoder = CorrelationDecoder.load(ns_model_fn) - ns_predictions_df = ns_decoder.transform(img) - feature_group = f"{source}-{category}_section-{section}_annot-tfidf__" - feature_names = ns_predictions_df.index.values - vocabulary_names = [f.replace(feature_group, "") for f in feature_names] - ns_predictions_df.index = vocabulary_names - - ns_predictions_df = ns_predictions_df.sort_values(by="r", ascending=False).head( - topk - ) - ns_predictions_df = ns_predictions_df.reset_index() - ns_predictions_df.columns = ["pred", "corr"] - ns_predictions_df.to_csv(op.join(output_dir, ns_out_fn), index=False) + if sub_category != "names": + continue + + baseline_label = f"{source}-{category}_embedding-{resources['model_name']}_section-{section}" + ns_out_fn = f"{file_base_name}_pred-task_neurosynth.csv" + gclda_out_fn = f"{file_base_name}_pred-task_gclda.csv" + + ns_model_fn = op.join(results_dir, "baseline", f"model-neurosynth_{baseline_label}.pkl") + gclda_model_fn = op.join(results_dir, "baseline", f"model-gclda_{baseline_label}.pkl") + + gclda_model = GCLDAModel.load(gclda_model_fn) + gclda_predictions_df, _ = gclda_decode_map(gclda_model, img) + gclda_predictions_df = gclda_predictions_df.sort_values(by="Weight", ascending=False).head( + topk + ) + gclda_predictions_df = gclda_predictions_df.reset_index() + gclda_predictions_df.columns = ["pred", "weight"] + gclda_predictions_df.to_csv(op.join(output_dir, gclda_out_fn), index=False) + + ns_decoder = CorrelationDecoder.load(ns_model_fn) + ns_predictions_df = ns_decoder.transform(img) + feature_group = f"{source}-{category}_section-{section}_annot-tfidf__" + feature_names = ns_predictions_df.index.values + vocabulary_names = [f.replace(feature_group, "") for f in feature_names] + ns_predictions_df.index = vocabulary_names + ns_predictions_df = ns_predictions_df.sort_values(by="r", ascending=False).head(topk) + ns_predictions_df = ns_predictions_df.reset_index() + ns_predictions_df.columns = ["pred", "corr"] + ns_predictions_df.to_csv(op.join(output_dir, ns_out_fn), index=False) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) if __name__ == "__main__": - main() + _main() diff --git a/jobs/decoding_ibc.py b/jobs/decoding_ibc.py index 905ba2b..3602ee5 100644 --- a/jobs/decoding_ibc.py +++ b/jobs/decoding_ibc.py @@ -1,164 +1,287 @@ +import argparse import itertools import os import os.path as op from glob import glob -import nibabel as nib import pandas as pd -from nilearn._utils.niimg_conversions import check_same_fov -from nilearn.image import resample_to_img -from nimare.annotate.gclda import GCLDAModel -from nimare.dataset import Dataset -from nimare.decode.continuous import CorrelationDecoder, gclda_decode_map -from utils import _read_vocabulary - -from braindec.cogatlas import CognitiveAtlas -from braindec.plot import plot_surf -from braindec.predict import image_to_labels_hierarchical - - -def main(): - project_dir = "/Users/julioaperaza/Documents/GitHub/brain-decoder" - project_dir = op.abspath(project_dir) - data_dir = op.join(project_dir, "data") - reduced = True +from tqdm.auto import tqdm + +from jobs.utils import ( + DEFAULT_MODEL_IDS, + DEFAULT_SECTIONS, + add_common_job_args, + build_cognitiveatlas, + get_source, + load_decoding_resources, + resolve_project_paths, +) + + +def _get_parser(): + parser = add_common_job_args(argparse.ArgumentParser(description="Decode IBC maps")) + parser.add_argument( + "--image_dir", + dest="image_dir", + default=None, + help="Optional explicit path to the IBC image directory.", + ) + parser.add_argument( + "--output_dir", + dest="output_dir", + default=None, + help="Optional explicit path to the output directory.", + ) + parser.add_argument( + "--categories", + dest="categories", + nargs="+", + default=["task"], + help="Vocabulary categories to evaluate.", + ) + parser.add_argument( + "--sub_categories", + dest="sub_categories", + nargs="+", + default=["combined"], + help="Vocabulary embedding variants to evaluate.", + ) + parser.add_argument( + "--make_plots", + dest="make_plots", + action="store_true", + help="Write surface renderings for each input image.", + ) + parser.add_argument( + "--mapping_fn", + dest="mapping_fn", + default=None, + help="Optional mapping CSV with a local_path column used to select images.", + ) + parser.add_argument( + "--num_shards", + dest="num_shards", + type=int, + default=1, + help="Split the selected image list into this many shards.", + ) + parser.add_argument( + "--shard_index", + dest="shard_index", + type=int, + default=0, + help="0-based shard index to decode from the sharded image list.", + ) + parser.add_argument( + "--skip_existing", + dest="skip_existing", + action="store_true", + help="Skip image/config outputs when all expected CSVs already exist.", + ) + return parser + + +def main( + project_dir=None, + data_dir=None, + results_dir=None, + image_dir=None, + output_dir=None, + sections=None, + categories=None, + sub_categories=None, + model_ids=None, + topk=20, + standardize=False, + logit_scale=20.0, + device=None, + reduced=True, + make_plots=False, + mapping_fn=None, + num_shards=1, + shard_index=0, + skip_existing=False, +): + import nibabel as nib + from nilearn.image import resample_to_img + from nimare.annotate.gclda import GCLDAModel + from nimare.decode.continuous import CorrelationDecoder, gclda_decode_map + + from braindec.embedding import ImageEmbedding + from braindec.model import build_model + from braindec.predict import image_to_labels_hierarchical + from braindec.utils import _get_device, images_have_same_fov + + _, data_dir, results_dir = resolve_project_paths(project_dir, data_dir, results_dir) + source = get_source(reduced) voc_dir = op.join(data_dir, "vocabulary") - source = "cogatlasred" if reduced else "cogatlas" - results_dir = op.join(project_dir, "results") - sections = ["abstract", "body"] - sub_categories = ["combined"] - categories = ["task"] - model_ids = [ - "BrainGPT/BrainGPT-7B-v0.2", - "mistralai/Mistral-7B-v0.1", - "BrainGPT/BrainGPT-7B-v0.1", - "meta-llama/Llama-2-7b-chat-hf", - ] - topk = 20 - standardize = False - logit_scale = 20 - device = "mps" - - output_dir = op.join(results_dir, "predictions_ibc") + image_dir = op.join(data_dir, "ibc") if image_dir is None else op.abspath(image_dir) + output_dir = ( + op.join(results_dir, "predictions_ibc") if output_dir is None else op.abspath(output_dir) + ) os.makedirs(output_dir, exist_ok=True) - dset = Dataset.load(op.join(data_dir, f"dset-pubmed_{source}-annotated_nimare.pkl")) + sections = list(DEFAULT_SECTIONS) if sections is None else sections + categories = ["task"] if categories is None else categories + sub_categories = ["combined"] if sub_categories is None else sub_categories + model_ids = list(DEFAULT_MODEL_IDS) if model_ids is None else model_ids + device = _get_device() if device is None else device + mask_img = nib.load(op.join(data_dir, "MNI152_2x2x2_brainmask.nii.gz")) - reduced_tasks_fn = op.join(data_dir, "cognitive_atlas", "reduced_tasks.csv") - reduced_tasks_df = pd.read_csv(reduced_tasks_fn) if reduced else None + def _resample_for_reference(image, reference_img): + if reference_img is None or images_have_same_fov(image, reference_img): + return image + return resample_to_img(image, reference_img) - cognitiveatlas = CognitiveAtlas( - data_dir=data_dir, - task_snapshot=op.join(data_dir, "cognitive_atlas", "task_snapshot-02-19-25.json"), - concept_snapshot=op.join( - data_dir, "cognitive_atlas", "concept_extended_snapshot-02-19-25.json" - ), - reduced_tasks=reduced_tasks_df, - ) + if num_shards < 1: + raise ValueError("--num_shards must be >= 1.") + if shard_index < 0 or shard_index >= num_shards: + raise ValueError("--shard_index must satisfy 0 <= shard_index < num_shards.") - image_dir = op.join(data_dir, "ibc") - images = sorted(glob(op.join(image_dir, "*.nii.gz"))) # [:10] + cognitiveatlas = build_cognitiveatlas(data_dir, reduced) + if mapping_fn is not None: + mapping_df = pd.read_csv(op.abspath(mapping_fn)) + if "local_path" not in mapping_df.columns: + raise KeyError(f"Mapping file {mapping_fn} must include a local_path column.") + images = sorted(mapping_df["local_path"].dropna().astype(str).tolist()) + else: + images = sorted(glob(op.join(image_dir, "*.nii.gz"))) + images = [img_fn for idx, img_fn in enumerate(images) if idx % num_shards == shard_index] - for _, img_fn in enumerate(images): + image_records = [] + for img_fn in tqdm(images, desc="load ibc images", unit="img"): image_name = op.basename(img_fn).split(".")[0] task_name = image_name.split("_")[0] - # Plot map for debugging - plot_surf( - img_fn, - op.join(output_dir, f"{task_name}_map.png"), - vmax=None, - ) + if make_plots: + from braindec.plot import plot_surf - img = nib.load(img_fn) - if not check_same_fov(img, reference_masker=dset.masker.mask_img): - img = resample_to_img(img, dset.masker.mask_img) - - for section, category, sub_category, model_id in itertools.product( - sections, categories, sub_categories, model_ids - ): - model_name = model_id.split("/")[-1] - model_path = op.join( - results_dir, - "pubmed", - f"model-clip_section-{section}_embedding-{model_name}_best.pth", - ) - vocabulary_lb = f"vocabulary-{source}_{category}-{sub_category}_embedding-{model_name}" - vocabulary_fn = op.join(voc_dir, f"vocabulary-{source}_{category}.txt") - vocabulary_emb_fn = op.join(voc_dir, f"{vocabulary_lb}.npy") - vocabulary_prior_fn = op.join(voc_dir, f"{vocabulary_lb}_section-{section}_prior.npy") - vocabulary, vocabulary_emb, vocabulary_prior = _read_vocabulary( - vocabulary_fn, - vocabulary_emb_fn, - vocabulary_prior_fn, + plot_surf( + img_fn, + op.join(output_dir, f"{task_name}_map.png"), + vmax=None, ) - file_base_name = f"{task_name}_{vocabulary_lb}_section-{section}" + img = nib.load(img_fn) + if not images_have_same_fov(img, mask_img): + img = resample_to_img(img, mask_img) + image_records.append((task_name, img)) + + for section, category, sub_category, model_id in itertools.product( + sections, + categories, + sub_categories, + model_ids, + ): + resources = load_decoding_resources( + results_dir, + voc_dir, + source, + category, + sub_category, + model_id, + section, + ) + model = build_model(resources["model_path"], device=device) + image_emb_gene = ImageEmbedding( + standardize=standardize, + nilearn_dir=op.join(data_dir, "nilearn"), + space="MNI152", + ) + + gclda_model = None + ns_decoder = None + ns_mask = None + if sub_category == "names": + baseline_label = f"{source}-{category}_embedding-{resources['model_name']}_section-{section}" + ns_model_fn = op.join(results_dir, "baseline", f"model-neurosynth_{baseline_label}.pkl") + gclda_model_fn = op.join(results_dir, "baseline", f"model-gclda_{baseline_label}.pkl") + + gclda_model = GCLDAModel.load(gclda_model_fn) + ns_decoder = CorrelationDecoder.load(ns_model_fn) + ns_masker = None + if hasattr(ns_decoder, "results_"): + ns_masker = getattr(ns_decoder.results_, "masker", None) + if ns_masker is None and hasattr(ns_decoder, "masker"): + ns_masker = ns_decoder.masker + if ns_masker is not None and not hasattr(ns_masker, "clean_args_"): + clean_kwargs = getattr(ns_masker, "clean_kwargs", None) + ns_masker.clean_args_ = {} if clean_kwargs is None else clean_kwargs + if ns_masker is not None: + ns_mask = getattr(ns_masker, "mask_img", None) + if ns_mask is None: + ns_mask = getattr(ns_masker, "mask_img_", None) + + desc = f"decode ibc {section}/{category}/{sub_category}/{resources['model_name']}" + for task_name, img in tqdm(image_records, desc=desc, unit="img"): + file_base_name = f"{task_name}_{resources['vocabulary_label']}_section-{section}" task_out_fn = f"{file_base_name}_pred-task_brainclip.csv" concept_out_fn = f"{file_base_name}_pred-concept_brainclip.csv" process_out_fn = f"{file_base_name}_pred-process_brainclip.csv" + expected_outputs = [ + op.join(output_dir, task_out_fn), + op.join(output_dir, concept_out_fn), + op.join(output_dir, process_out_fn), + ] + if sub_category == "names": + expected_outputs.extend( + [ + op.join(output_dir, f"{file_base_name}_pred-task_neurosynth.csv"), + op.join(output_dir, f"{file_base_name}_pred-task_gclda.csv"), + ] + ) + if skip_existing and all(op.exists(path) for path in expected_outputs): + continue task_prob_df, concept_prob_df, process_prob_df = image_to_labels_hierarchical( img, - model_path, - vocabulary, - vocabulary_emb, - vocabulary_prior, - cognitiveatlas.concept_to_task_idxs, - cognitiveatlas.process_to_concept_idxs, - cognitiveatlas.concept_names, - cognitiveatlas.process_names, + resources["model_path"], + resources["vocabulary"], + resources["vocabulary_emb"], + resources["vocabulary_prior"], + cognitiveatlas, topk=topk, standardize=standardize, logit_scale=logit_scale, data_dir=data_dir, device=device, + model=model, + image_emb_gene=image_emb_gene, ) task_prob_df.to_csv(op.join(output_dir, task_out_fn), index=False) concept_prob_df.to_csv(op.join(output_dir, concept_out_fn), index=False) process_prob_df.to_csv(op.join(output_dir, process_out_fn), index=False) - if sub_category == "names": - baseline_label = f"{source}-{category}_embedding-{model_name}_section-{section}" - # Baseline - # -------------------------------------------------------------------- - ns_out_fn = f"{file_base_name}_pred-task_neurosynth.csv" - gclda_out_fn = f"{file_base_name}_pred-task_gclda.csv" - - # Load baseline model - ns_model_fn = op.join( - results_dir, - "baseline", - f"model-neurosynth_{baseline_label}.pkl", - ) - gclda_model_fn = op.join( - results_dir, - "baseline", - f"model-gclda_{baseline_label}.pkl", - ) + if sub_category != "names": + continue - gclda_model = GCLDAModel.load(gclda_model_fn) - gclda_predictions_df, _ = gclda_decode_map(gclda_model, img) - gclda_predictions_df = gclda_predictions_df.sort_values( - by="Weight", ascending=False - ).head(topk) - gclda_predictions_df = gclda_predictions_df.reset_index() - gclda_predictions_df.columns = ["pred", "weight"] - gclda_predictions_df.to_csv(op.join(output_dir, gclda_out_fn), index=False) - - ns_decoder = CorrelationDecoder.load(ns_model_fn) - ns_predictions_df = ns_decoder.transform(img) - feature_group = f"{source}-{category}_section-{section}_annot-tfidf__" - feature_names = ns_predictions_df.index.values - vocabulary_names = [f.replace(feature_group, "") for f in feature_names] - ns_predictions_df.index = vocabulary_names - - ns_predictions_df = ns_predictions_df.sort_values(by="r", ascending=False).head( - topk - ) - ns_predictions_df = ns_predictions_df.reset_index() - ns_predictions_df.columns = ["pred", "corr"] - ns_predictions_df.to_csv(op.join(output_dir, ns_out_fn), index=False) + ns_out_fn = f"{file_base_name}_pred-task_neurosynth.csv" + gclda_out_fn = f"{file_base_name}_pred-task_gclda.csv" + + gclda_img = _resample_for_reference(img, getattr(gclda_model, "mask", None)) + gclda_predictions_df, _ = gclda_decode_map(gclda_model, gclda_img) + gclda_predictions_df = gclda_predictions_df.sort_values(by="Weight", ascending=False).head( + topk + ) + gclda_predictions_df = gclda_predictions_df.reset_index() + gclda_predictions_df.columns = ["pred", "weight"] + gclda_predictions_df.to_csv(op.join(output_dir, gclda_out_fn), index=False) + + ns_img = _resample_for_reference(img, ns_mask) + ns_predictions_df = ns_decoder.transform(ns_img) + feature_group = f"{source}-{category}_section-{section}_annot-tfidf__" + feature_names = ns_predictions_df.index.values + vocabulary_names = [f.replace(feature_group, "") for f in feature_names] + ns_predictions_df.index = vocabulary_names + ns_predictions_df = ns_predictions_df.sort_values(by="r", ascending=False).head(topk) + ns_predictions_df = ns_predictions_df.reset_index() + ns_predictions_df.columns = ["pred", "corr"] + ns_predictions_df.to_csv(op.join(output_dir, ns_out_fn), index=False) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) if __name__ == "__main__": - main() + _main() diff --git a/jobs/decoding_seeds.py b/jobs/decoding_seeds.py index a4eed0e..f36c593 100644 --- a/jobs/decoding_seeds.py +++ b/jobs/decoding_seeds.py @@ -1,63 +1,103 @@ +import argparse import itertools import os import os.path as op from glob import glob -import nibabel as nib -import pandas as pd -from nilearn._utils.niimg_conversions import check_same_fov -from nilearn.image import resample_to_img -from nimare.annotate.gclda import GCLDAModel -from nimare.dataset import Dataset -from nimare.decode.continuous import CorrelationDecoder, gclda_decode_map -from utils import _read_vocabulary - -from braindec.cogatlas import CognitiveAtlas -from braindec.plot import plot_vol_roi -from braindec.predict import image_to_labels_hierarchical - - -def main(): - project_dir = "/Users/julioaperaza/Documents/GitHub/brain-decoder" - project_dir = op.abspath(project_dir) - data_dir = op.join(project_dir, "data") - reduced = True +from jobs.utils import ( + DEFAULT_MODEL_IDS, + DEFAULT_SECTIONS, + add_common_job_args, + build_cognitiveatlas, + get_source, + load_decoding_resources, + resolve_project_paths, +) + + +def _get_parser(): + parser = add_common_job_args(argparse.ArgumentParser(description="Decode ROI seed maps")) + parser.add_argument( + "--image_dir", + dest="image_dir", + default=None, + help="Optional explicit path to the ROI seed image directory.", + ) + parser.add_argument( + "--output_dir", + dest="output_dir", + default=None, + help="Optional explicit path to the output directory.", + ) + parser.add_argument( + "--categories", + dest="categories", + nargs="+", + default=["task"], + help="Vocabulary categories to evaluate.", + ) + parser.add_argument( + "--sub_categories", + dest="sub_categories", + nargs="+", + default=["combined"], + help="Vocabulary embedding variants to evaluate.", + ) + return parser + + +def main( + project_dir=None, + data_dir=None, + results_dir=None, + image_dir=None, + output_dir=None, + sections=None, + categories=None, + sub_categories=None, + model_ids=None, + topk=20, + standardize=False, + logit_scale=20.0, + device=None, + reduced=True, +): + import nibabel as nib + from nilearn.image import resample_to_img + from nimare.annotate.gclda import GCLDAModel + from nimare.dataset import Dataset + from nimare.decode.continuous import CorrelationDecoder, gclda_decode_map + + from braindec.plot import plot_vol_roi + from braindec.predict import image_to_labels_hierarchical + from braindec.utils import _get_device, images_have_same_fov + + _, data_dir, results_dir = resolve_project_paths(project_dir, data_dir, results_dir) + source = get_source(reduced) voc_dir = op.join(data_dir, "vocabulary") - source = "cogatlasred" if reduced else "cogatlas" - results_dir = op.join(project_dir, "results") - sections = ["abstract", "body"] - sub_categories = ["combined"] - categories = ["task"] - colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255)] - model_ids = [ - "BrainGPT/BrainGPT-7B-v0.2", - "mistralai/Mistral-7B-v0.1", - "BrainGPT/BrainGPT-7B-v0.1", - "meta-llama/Llama-2-7b-chat-hf", - ] - topk = 20 # top k predictions - standardize = False - logit_scale = 20 # None - device = "mps" - - output_dir = op.join(results_dir, "predictions_rois") + image_dir = op.join(data_dir, "seed-regions") if image_dir is None else op.abspath(image_dir) + output_dir = ( + op.join(results_dir, "predictions_rois") if output_dir is None else op.abspath(output_dir) + ) os.makedirs(output_dir, exist_ok=True) - dset = Dataset.load(op.join(data_dir, f"dset-pubmed_{source}-annotated_nimare.pkl")) - - reduced_tasks_fn = op.join(data_dir, "cognitive_atlas", "reduced_tasks.csv") - reduced_tasks_df = pd.read_csv(reduced_tasks_fn) if reduced else None - - cognitiveatlas = CognitiveAtlas( - data_dir=data_dir, - task_snapshot=op.join(data_dir, "cognitive_atlas", "task_snapshot-02-19-25.json"), - concept_snapshot=op.join( - data_dir, "cognitive_atlas", "concept_extended_snapshot-02-19-25.json" - ), - reduced_tasks=reduced_tasks_df, - ) + sections = list(DEFAULT_SECTIONS) if sections is None else sections + categories = ["task"] if categories is None else categories + sub_categories = ["combined"] if sub_categories is None else sub_categories + model_ids = list(DEFAULT_MODEL_IDS) if model_ids is None else model_ids + device = _get_device() if device is None else device + + colors = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (0, 255, 255), + (255, 0, 255), + ] - image_dir = op.join(data_dir, "seed-regions") + dset = Dataset.load(op.join(data_dir, f"dset-pubmed_{source}-annotated_nimare.pkl")) + cognitiveatlas = build_cognitiveatlas(data_dir, reduced) images = sorted(glob(op.join(image_dir, "*.nii.gz"))) for img_i, img_fn in enumerate(images): @@ -65,47 +105,41 @@ def main(): plot_vol_roi( img_fn, op.join(output_dir, f"{image_name}_map.png"), - color=colors[img_i], + color=colors[img_i % len(colors)], ) img = nib.load(img_fn) - if not check_same_fov(img, reference_masker=dset.masker.mask_img): + if not images_have_same_fov(img, dset.masker.mask_img): img = resample_to_img(img, dset.masker.mask_img) for section, category, sub_category, model_id in itertools.product( - sections, categories, sub_categories, model_ids + sections, + categories, + sub_categories, + model_ids, ): - model_name = model_id.split("/")[-1] - model_path = op.join( + resources = load_decoding_resources( results_dir, - "pubmed", - f"model-clip_section-{section}_embedding-{model_name}_best.pth", - ) - vocabulary_lb = f"vocabulary-{source}_{category}-{sub_category}_embedding-{model_name}" - vocabulary_fn = op.join(voc_dir, f"vocabulary-{source}_{category}.txt") - vocabulary_emb_fn = op.join(voc_dir, f"{vocabulary_lb}.npy") - vocabulary_prior_fn = op.join(voc_dir, f"{vocabulary_lb}_section-{section}_prior.npy") - vocabulary, vocabulary_emb, vocabulary_prior = _read_vocabulary( - vocabulary_fn, - vocabulary_emb_fn, - vocabulary_prior_fn, + voc_dir, + source, + category, + sub_category, + model_id, + section, ) - file_base_name = f"{image_name}_{vocabulary_lb}_section-{section}" + file_base_name = f"{image_name}_{resources['vocabulary_label']}_section-{section}" task_out_fn = f"{file_base_name}_pred-task_brainclip.csv" concept_out_fn = f"{file_base_name}_pred-concept_brainclip.csv" process_out_fn = f"{file_base_name}_pred-process_brainclip.csv" task_prob_df, concept_prob_df, process_prob_df = image_to_labels_hierarchical( img, - model_path, - vocabulary, - vocabulary_emb, - vocabulary_prior, - cognitiveatlas.concept_to_task_idxs, - cognitiveatlas.process_to_concept_idxs, - cognitiveatlas.concept_names, - cognitiveatlas.process_names, + resources["model_path"], + resources["vocabulary"], + resources["vocabulary_emb"], + resources["vocabulary_prior"], + cognitiveatlas, topk=topk, standardize=standardize, logit_scale=logit_scale, @@ -116,48 +150,41 @@ def main(): concept_prob_df.to_csv(op.join(output_dir, concept_out_fn), index=False) process_prob_df.to_csv(op.join(output_dir, process_out_fn), index=False) - if sub_category == "names": - baseline_label = f"{source}-{category}_embedding-{model_name}_section-{section}" - # Baseline - # -------------------------------------------------------------------- - ns_out_fn = f"{file_base_name}_pred-task_neurosynth.csv" - gclda_out_fn = f"{file_base_name}_pred-task_gclda.csv" - - # Load baseline model - ns_model_fn = op.join( - results_dir, - "baseline", - f"model-neurosynth_{baseline_label}.pkl", - ) - gclda_model_fn = op.join( - results_dir, - "baseline", - f"model-gclda_{baseline_label}.pkl", - ) - - gclda_model = GCLDAModel.load(gclda_model_fn) - gclda_predictions_df, _ = gclda_decode_map(gclda_model, img) - gclda_predictions_df = gclda_predictions_df.sort_values( - by="Weight", ascending=False - ).head(topk) - gclda_predictions_df = gclda_predictions_df.reset_index() - gclda_predictions_df.columns = ["pred", "weight"] - gclda_predictions_df.to_csv(op.join(output_dir, gclda_out_fn), index=False) - - ns_decoder = CorrelationDecoder.load(ns_model_fn) - ns_predictions_df = ns_decoder.transform(img) - feature_group = f"{source}-{category}_section-{section}_annot-tfidf__" - feature_names = ns_predictions_df.index.values - vocabulary_names = [f.replace(feature_group, "") for f in feature_names] - ns_predictions_df.index = vocabulary_names - - ns_predictions_df = ns_predictions_df.sort_values(by="r", ascending=False).head( - topk - ) - ns_predictions_df = ns_predictions_df.reset_index() - ns_predictions_df.columns = ["pred", "corr"] - ns_predictions_df.to_csv(op.join(output_dir, ns_out_fn), index=False) + if sub_category != "names": + continue + + baseline_label = f"{source}-{category}_embedding-{resources['model_name']}_section-{section}" + ns_out_fn = f"{file_base_name}_pred-task_neurosynth.csv" + gclda_out_fn = f"{file_base_name}_pred-task_gclda.csv" + + ns_model_fn = op.join(results_dir, "baseline", f"model-neurosynth_{baseline_label}.pkl") + gclda_model_fn = op.join(results_dir, "baseline", f"model-gclda_{baseline_label}.pkl") + + gclda_model = GCLDAModel.load(gclda_model_fn) + gclda_predictions_df, _ = gclda_decode_map(gclda_model, img) + gclda_predictions_df = gclda_predictions_df.sort_values(by="Weight", ascending=False).head( + topk + ) + gclda_predictions_df = gclda_predictions_df.reset_index() + gclda_predictions_df.columns = ["pred", "weight"] + gclda_predictions_df.to_csv(op.join(output_dir, gclda_out_fn), index=False) + + ns_decoder = CorrelationDecoder.load(ns_model_fn) + ns_predictions_df = ns_decoder.transform(img) + feature_group = f"{source}-{category}_section-{section}_annot-tfidf__" + feature_names = ns_predictions_df.index.values + vocabulary_names = [f.replace(feature_group, "") for f in feature_names] + ns_predictions_df.index = vocabulary_names + ns_predictions_df = ns_predictions_df.sort_values(by="r", ascending=False).head(topk) + ns_predictions_df = ns_predictions_df.reset_index() + ns_predictions_df.columns = ["pred", "corr"] + ns_predictions_df.to_csv(op.join(output_dir, ns_out_fn), index=False) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) if __name__ == "__main__": - main() + _main() diff --git a/jobs/embedding_geometry.py b/jobs/embedding_geometry.py new file mode 100644 index 0000000..d8aa46c --- /dev/null +++ b/jobs/embedding_geometry.py @@ -0,0 +1,248 @@ +"""Project task and HCP image embeddings into 2D and summarize latent-space structure.""" + +import argparse +import os +import os.path as op + +import matplotlib.pyplot as plt +import nibabel as nib +import numpy as np +import pandas as pd +import torch +from sklearn.decomposition import PCA +from sklearn.manifold import TSNE + +from braindec.embedding import ImageEmbedding +from braindec.model import build_model +from jobs.utils import build_cognitiveatlas, get_model_name, resolve_project_paths + + +DEFAULT_MODEL_IDS = [ + "BrainGPT/BrainGPT-7B-v0.2", + "mistralai/Mistral-7B-v0.1", + "meta-llama/Llama-2-7b-chat-hf", +] + + +def _get_parser(): + parser = argparse.ArgumentParser( + description="Inspect shared text-image embedding geometry on HCP representative maps." + ) + parser.add_argument("--project_dir", dest="project_dir", default=None) + parser.add_argument("--data_dir", dest="data_dir", default=None) + parser.add_argument("--results_dir", dest="results_dir", default=None) + parser.add_argument( + "--hcp_mapping_fn", + dest="hcp_mapping_fn", + default=None, + help="Benchmark mapping CSV with representative HCP maps.", + ) + parser.add_argument( + "--model_ids", + dest="model_ids", + nargs="+", + default=list(DEFAULT_MODEL_IDS), + help="Models to compare.", + ) + parser.add_argument( + "--section", + dest="section", + default="body", + help="Model section to use.", + ) + parser.add_argument( + "--source", + dest="source", + default="cogatlasred", + help="Vocabulary source to use.", + ) + parser.add_argument("--coords_fn", dest="coords_fn", required=True) + parser.add_argument("--distance_fn", dest="distance_fn", required=True) + parser.add_argument("--plot_fn", dest="plot_fn", required=True) + parser.add_argument("--device", dest="device", default=None) + return parser + + +def _normalize_rows(array): + norms = np.linalg.norm(array, axis=1, keepdims=True) + 1e-8 + return array / norms + + +def _first_domain(atlas, task_name): + task_idx = atlas.get_task_idx_from_names(task_name) + process_idxs = atlas.task_to_process_idxs[task_idx] + if len(process_idxs) == 0: + return None + return atlas.get_process_names_from_idx(process_idxs)[0] + + +def main( + project_dir=None, + data_dir=None, + results_dir=None, + hcp_mapping_fn=None, + model_ids=None, + section="body", + source="cogatlasred", + coords_fn=None, + distance_fn=None, + plot_fn=None, + device=None, +): + _, data_dir, results_dir = resolve_project_paths(project_dir, data_dir, results_dir) + hcp_mapping_fn = ( + op.join(".runs", "hcp_benchmark_best", "data", "hcp", "benchmark_mapping.csv") + if hcp_mapping_fn is None + else op.abspath(hcp_mapping_fn) + ) + hcp_mapping_fn = op.abspath(hcp_mapping_fn) + atlas = build_cognitiveatlas(data_dir, reduced=(source == "cogatlasred")) + hcp_df = pd.read_csv(hcp_mapping_fn) + + image_emb_gene = ImageEmbedding(standardize=False, nilearn_dir=op.join(data_dir, "nilearn")) + hcp_images = [nib.load(path) for path in hcp_df["image_path"].tolist()] + image_inputs = image_emb_gene(hcp_images) + image_inputs = _normalize_rows(image_inputs) + + coords_rows = [] + distance_rows = [] + for model_id in model_ids: + model_name = get_model_name(model_id) + model_path = op.join( + results_dir, + "pubmed", + f"model-clip_section-{section}_embedding-{model_name}_best.pth", + ) + vocab_emb = np.load( + op.join( + data_dir, + "vocabulary", + f"vocabulary-{source}_task-combined_embedding-{model_name}.npy", + ) + ) + with open(op.join(data_dir, "vocabulary", f"vocabulary-{source}_task.txt"), "r") as file_obj: + task_vocab = [line.strip() for line in file_obj] + + model = build_model(model_path, device=device or "cpu") + with torch.no_grad(): + text_features = ( + model.encode_text(torch.from_numpy(vocab_emb).float().to(model.device)) + .cpu() + .numpy() + ) + image_features = ( + model.encode_image(torch.from_numpy(image_inputs).float().to(model.device)) + .cpu() + .numpy() + ) + text_features = _normalize_rows(text_features) + image_features = _normalize_rows(image_features) + + domain_labels = [_first_domain(atlas, task_name) for task_name in task_vocab] + combined = np.vstack([text_features, image_features]) + pca = PCA(n_components=min(20, combined.shape[1], combined.shape[0] - 1), random_state=0) + reduced = pca.fit_transform(combined) + perplexity = max(2, min(15, combined.shape[0] - 1)) + coords = TSNE( + n_components=2, + init="pca", + learning_rate="auto", + perplexity=perplexity, + random_state=0, + ).fit_transform(reduced) + + text_coords = coords[: len(task_vocab)] + image_coords = coords[len(task_vocab) :] + for idx, task_name in enumerate(task_vocab): + coords_rows.append( + { + "model_name": model_name, + "point_type": "task_text", + "label": task_name, + "group": domain_labels[idx], + "x": float(text_coords[idx, 0]), + "y": float(text_coords[idx, 1]), + } + ) + for idx, row in hcp_df.iterrows(): + coords_rows.append( + { + "model_name": model_name, + "point_type": "hcp_image", + "label": row["task_name"], + "group": row["domain_key"], + "x": float(image_coords[idx, 0]), + "y": float(image_coords[idx, 1]), + } + ) + + task_to_feature = {task_name: text_features[idx] for idx, task_name in enumerate(task_vocab)} + cosine = text_features @ text_features.T + same_domain = [] + different_domain = [] + for i in range(len(task_vocab)): + for j in range(i + 1, len(task_vocab)): + if domain_labels[i] is None or domain_labels[j] is None: + continue + if domain_labels[i] == domain_labels[j]: + same_domain.append(cosine[i, j]) + else: + different_domain.append(cosine[i, j]) + matched_cosines = [] + for idx, row in hcp_df.iterrows(): + if row["task_name"] not in task_to_feature: + continue + matched_cosines.append(float(np.dot(image_features[idx], task_to_feature[row["task_name"]]))) + distance_rows.extend( + [ + { + "model_name": model_name, + "metric": "mean_within_domain_task_cosine", + "value": float(np.mean(same_domain)), + }, + { + "model_name": model_name, + "metric": "mean_between_domain_task_cosine", + "value": float(np.mean(different_domain)), + }, + { + "model_name": model_name, + "metric": "mean_hcp_image_to_matched_task_cosine", + "value": float(np.mean(matched_cosines)), + }, + ] + ) + + coords_df = pd.DataFrame(coords_rows) + distance_df = pd.DataFrame(distance_rows) + os.makedirs(op.dirname(op.abspath(coords_fn)), exist_ok=True) + coords_df.to_csv(coords_fn, index=False) + distance_df.to_csv(distance_fn, index=False) + + model_names = [get_model_name(model_id) for model_id in model_ids] + fig, axes = plt.subplots(1, len(model_names), figsize=(6 * len(model_names), 5), squeeze=False) + for ax, model_name in zip(axes[0], model_names): + sub = coords_df.loc[coords_df["model_name"] == model_name] + text_sub = sub.loc[sub["point_type"] == "task_text"] + image_sub = sub.loc[sub["point_type"] == "hcp_image"] + ax.scatter(text_sub["x"], text_sub["y"], s=12, alpha=0.35, label="Task text") + ax.scatter(image_sub["x"], image_sub["y"], s=60, marker="x", label="HCP image") + for row in image_sub.itertuples(index=False): + ax.text(row.x, row.y, row.label, fontsize=7) + ax.set_title(model_name) + ax.set_xlabel("Dim 1") + ax.set_ylabel("Dim 2") + ax.legend(loc="best") + fig.tight_layout() + os.makedirs(op.dirname(op.abspath(plot_fn)), exist_ok=True) + fig.savefig(plot_fn, dpi=200, bbox_inches="tight") + plt.close(fig) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) + + +if __name__ == "__main__": + _main() diff --git a/jobs/emotion_analysis.py b/jobs/emotion_analysis.py new file mode 100644 index 0000000..4fa5d26 --- /dev/null +++ b/jobs/emotion_analysis.py @@ -0,0 +1,144 @@ +"""Audit emotion-task coverage and summarize emotion-map decoding performance.""" + +import argparse +import json +import os +import os.path as op + +import pandas as pd + + +EMOTION_KEYWORDS = [ + "emotion", + "emotional", + "fear", + "anger", + "disgust", + "happiness", + "happy", + "sad", + "face", +] + + +def _get_parser(): + parser = argparse.ArgumentParser( + description="Summarize fine-grained emotion coverage and decoding on available benchmark maps." + ) + parser.add_argument("--ibc_mapping_fn", dest="ibc_mapping_fn", required=True) + parser.add_argument("--details_fn", dest="details_fn", required=True) + parser.add_argument("--full_vocab_fn", dest="full_vocab_fn", required=True) + parser.add_argument("--reduced_vocab_fn", dest="reduced_vocab_fn", required=True) + parser.add_argument("--concept_snapshot_fn", dest="concept_snapshot_fn", required=True) + parser.add_argument("--coverage_fn", dest="coverage_fn", required=True) + parser.add_argument("--summary_fn", dest="summary_fn", required=True) + return parser + + +def _contains_emotion(text): + text = str(text).lower() + return any(keyword in text for keyword in EMOTION_KEYWORDS) + + +def _parse_json_list(value): + if pd.isna(value): + return [] + return list(json.loads(value)) + + +def _topk_emotion_hits(top_predictions_json): + predictions = _parse_json_list(top_predictions_json) + return [pred for pred in predictions if _contains_emotion(pred)] + + +def main( + ibc_mapping_fn, + details_fn, + full_vocab_fn, + reduced_vocab_fn, + concept_snapshot_fn, + coverage_fn, + summary_fn, +): + ibc_mapping_df = pd.read_csv(op.abspath(ibc_mapping_fn)) + details_df = pd.read_csv(op.abspath(details_fn)) + + emotion_maps_df = ibc_mapping_df.loc[ + ibc_mapping_df["task_family"].astype(str).str.contains("emotion|emotional", case=False, na=False) + | ibc_mapping_df["task_name"].astype(str).map(_contains_emotion) + | ibc_mapping_df["contrast_definition"].astype(str).map(_contains_emotion) + ].copy() + emotion_labels = set(emotion_maps_df["prediction_label"].astype(str)) + + with open(op.abspath(full_vocab_fn), "r") as file_obj: + full_vocab = [line.strip() for line in file_obj] + with open(op.abspath(reduced_vocab_fn), "r") as file_obj: + reduced_vocab = [line.strip() for line in file_obj] + with open(op.abspath(concept_snapshot_fn), "r") as file_obj: + concepts = json.load(file_obj) + + concept_names = [ + concept.get("name", "") + for concept in concepts + if concept.get("name") and concept.get("definition_text") + ] + coverage_rows = [] + for source_name, terms in [ + ("full_task_vocabulary", full_vocab), + ("reduced_task_vocabulary", reduced_vocab), + ("concept_snapshot", concept_names), + ]: + hits = [term for term in terms if _contains_emotion(term)] + for term in hits: + coverage_rows.append({"source": source_name, "term": term}) + + coverage_df = pd.DataFrame(coverage_rows) + os.makedirs(op.dirname(op.abspath(coverage_fn)), exist_ok=True) + coverage_df.to_csv(coverage_fn, index=False) + + emotion_details_df = details_df.loc[ + (details_df["dataset"].astype(str).isin(["ibc", "ibc_full"])) + & (details_df["backend"] == "brainclip") + & (details_df["prediction_label"].astype(str).isin(emotion_labels)) + ].copy() + emotion_details_df["n_emotion_predictions_top20"] = emotion_details_df["top_predictions_json"].map( + lambda value: len(_topk_emotion_hits(value)) + ) + emotion_details_df["top_emotion_predictions_json"] = emotion_details_df["top_predictions_json"].map( + lambda value: json.dumps(_topk_emotion_hits(value)) + ) + for column in ["recall_at_k", "hit_at_k", "best_rank", "n_emotion_predictions_top20"]: + emotion_details_df[column] = pd.to_numeric(emotion_details_df[column], errors="coerce") + + grouped = emotion_details_df.groupby( + ["sub_category", "level"], + dropna=False, + sort=False, + ) + summary_df = ( + grouped.agg( + n_images=("prediction_label", "size"), + mean_recall_at_k=("recall_at_k", "mean"), + mean_hit_at_k=("hit_at_k", "mean"), + median_best_rank=("best_rank", "median"), + mean_n_emotion_predictions_top20=("n_emotion_predictions_top20", "mean"), + ) + .reset_index() + ) + summary_df.insert(0, "n_emotion_maps", len(emotion_labels)) + summary_df.insert( + 1, + "ground_truth_granularity_note", + "Available IBC emotion maps are coarse face/shape or emotional localizer contrasts rather than specific emotion labels.", + ) + os.makedirs(op.dirname(op.abspath(summary_fn)), exist_ok=True) + summary_df.to_csv(summary_fn, index=False) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) + + +if __name__ == "__main__": + _main() diff --git a/jobs/nsd_pilot.py b/jobs/nsd_pilot.py new file mode 100644 index 0000000..76d2eae --- /dev/null +++ b/jobs/nsd_pilot.py @@ -0,0 +1,57 @@ +"""Run or document the optional NSD pilot status.""" + +import argparse +import os +import os.path as op + +import pandas as pd + +from jobs.utils import resolve_project_paths + + +def _get_parser(): + parser = argparse.ArgumentParser(description="Optional NSD pilot status report.") + parser.add_argument("--project_dir", dest="project_dir", default=None) + parser.add_argument("--data_dir", dest="data_dir", default=None) + parser.add_argument("--results_dir", dest="results_dir", default=None) + parser.add_argument("--nsd_dir", dest="nsd_dir", default=None) + parser.add_argument("--output_fn", dest="output_fn", required=True) + return parser + + +def main(project_dir=None, data_dir=None, results_dir=None, nsd_dir=None, output_fn=None): + _, data_dir, _ = resolve_project_paths(project_dir, data_dir, results_dir) + nsd_dir = op.join(data_dir, "nsd") if nsd_dir is None else op.abspath(nsd_dir) + + if not op.isdir(nsd_dir): + report_df = pd.DataFrame( + [ + { + "status": "blocked", + "reason": "nsd_maps_missing", + "expected_path": nsd_dir, + } + ] + ) + else: + report_df = pd.DataFrame( + [ + { + "status": "ready", + "reason": "nsd_maps_present", + "expected_path": nsd_dir, + } + ] + ) + + os.makedirs(op.dirname(op.abspath(output_fn)), exist_ok=True) + report_df.to_csv(output_fn, index=False) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) + + +if __name__ == "__main__": + _main() diff --git a/jobs/ontology_comparison.py b/jobs/ontology_comparison.py new file mode 100644 index 0000000..7010232 --- /dev/null +++ b/jobs/ontology_comparison.py @@ -0,0 +1,198 @@ +"""Compare reduced versus full Cognitive Atlas decoding performance.""" + +import argparse +import json +import os +import os.path as op + +import numpy as np +import pandas as pd + +from jobs.utils import build_cognitiveatlas, resolve_project_paths + + +def _get_parser(): + parser = argparse.ArgumentParser( + description="Summarize reduced versus full ontology structure and decoding performance." + ) + parser.add_argument("--project_dir", dest="project_dir", default=None) + parser.add_argument("--data_dir", dest="data_dir", default=None) + parser.add_argument("--results_dir", dest="results_dir", default=None) + parser.add_argument("--reduced_eval_fn", dest="reduced_eval_fn", required=True) + parser.add_argument("--reduced_details_fn", dest="reduced_details_fn", required=True) + parser.add_argument("--full_eval_fn", dest="full_eval_fn", required=True) + parser.add_argument("--full_details_fn", dest="full_details_fn", required=True) + parser.add_argument("--full_mapping_fn", dest="full_mapping_fn", required=True) + parser.add_argument("--dataset_name", dest="dataset_name", required=True) + parser.add_argument("--ontology_stats_fn", dest="ontology_stats_fn", required=True) + parser.add_argument("--comparison_fn", dest="comparison_fn", required=True) + parser.add_argument( + "--term_table_fn", + dest="term_table_fn", + default=None, + help="Optional path for a task-level retained/removed/enriched ontology table.", + ) + return parser + + +def _ontology_stats(data_dir): + rows = [] + for label, reduced in [("reduced", True), ("full", False)]: + atlas = build_cognitiveatlas(data_dir, reduced=reduced) + rows.append( + { + "ontology": label, + "n_tasks": int(len(atlas.task_names)), + "n_concepts": int(len(atlas.concept_names)), + "n_domains": int(len(atlas.process_names)), + "n_task_concept_edges": int(sum(len(idxs) for idxs in atlas.task_to_concept_idxs)), + "n_concept_domain_edges": int(sum(len(idxs) for idxs in atlas.process_to_concept_idxs)), + } + ) + return pd.DataFrame(rows) + + +def _aggregate_from_details(details_df): + grouped = details_df.groupby( + [ + "dataset", + "source", + "category", + "sub_category", + "section", + "model_id", + "model_name", + "vocabulary_label", + "backend", + "level", + "k", + ], + dropna=False, + sort=False, + ) + return ( + grouped.agg( + n_images=("prediction_label", "size"), + mean_recall_at_k=("recall_at_k", "mean"), + mean_hit_at_k=("hit_at_k", "mean"), + median_best_rank=("best_rank", lambda x: np.nanmedian(np.where(np.isfinite(x), x, np.nan))), + ) + .reset_index() + ) + + +def _task_term_table(data_dir): + full_atlas = build_cognitiveatlas(data_dir, reduced=False) + reduced_atlas = build_cognitiveatlas(data_dir, reduced=True) + + full_task_to_concepts = { + task_name: {full_atlas.concept_names[idx] for idx in concept_idxs} + for task_name, concept_idxs in zip(full_atlas.task_names, full_atlas.task_to_concept_idxs) + } + reduced_task_to_concepts = { + task_name: {reduced_atlas.concept_names[idx] for idx in concept_idxs} + for task_name, concept_idxs in zip(reduced_atlas.task_names, reduced_atlas.task_to_concept_idxs) + } + + rows = [] + all_tasks = sorted(set(full_task_to_concepts) | set(reduced_task_to_concepts)) + for task_name in all_tasks: + full_concepts = full_task_to_concepts.get(task_name, set()) + reduced_concepts = reduced_task_to_concepts.get(task_name, set()) + shared_concepts = full_concepts & reduced_concepts + + if task_name not in reduced_task_to_concepts: + status = "removed_in_reduced" + elif full_concepts == reduced_concepts: + status = "retained_same_concepts" + elif reduced_concepts > full_concepts: + status = "retained_enriched_concepts" + elif reduced_concepts < full_concepts: + status = "retained_pruned_concepts" + else: + status = "retained_rewired_concepts" + + rows.append( + { + "task_name": task_name, + "in_full_ontology": task_name in full_task_to_concepts, + "in_reduced_ontology": task_name in reduced_task_to_concepts, + "status": status, + "n_full_concepts": len(full_concepts), + "n_reduced_concepts": len(reduced_concepts), + "n_shared_concepts": len(shared_concepts), + "full_concepts_json": json.dumps(sorted(full_concepts)), + "reduced_concepts_json": json.dumps(sorted(reduced_concepts)), + } + ) + + return pd.DataFrame(rows) + + +def main( + project_dir=None, + data_dir=None, + results_dir=None, + reduced_eval_fn=None, + reduced_details_fn=None, + full_eval_fn=None, + full_details_fn=None, + full_mapping_fn=None, + dataset_name=None, + ontology_stats_fn=None, + comparison_fn=None, + term_table_fn=None, +): + _, data_dir, _ = resolve_project_paths(project_dir, data_dir, results_dir) + reduced_eval_df = pd.read_csv(op.abspath(reduced_eval_fn)) + reduced_details_df = pd.read_csv(op.abspath(reduced_details_fn)) + full_eval_df = pd.read_csv(op.abspath(full_eval_fn)) + full_details_df = pd.read_csv(op.abspath(full_details_fn)) + full_mapping_df = pd.read_csv(op.abspath(full_mapping_fn)) + + stats_df = _ontology_stats(data_dir) + os.makedirs(op.dirname(op.abspath(ontology_stats_fn)), exist_ok=True) + stats_df.to_csv(ontology_stats_fn, index=False) + + if term_table_fn is not None: + term_table_df = _task_term_table(data_dir) + os.makedirs(op.dirname(op.abspath(term_table_fn)), exist_ok=True) + term_table_df.to_csv(term_table_fn, index=False) + + reduced_eval_df = reduced_eval_df.assign(ontology="reduced", evaluation_scope="all_images") + full_eval_df = full_eval_df.assign(ontology="full", evaluation_scope="all_images") + + overlap_labels = set( + full_mapping_df.loc[full_mapping_df["in_reduced_ontology"].fillna(False), "prediction_label"].astype(str) + ) + full_overlap_df = full_details_df.loc[full_details_df["prediction_label"].astype(str).isin(overlap_labels)].copy() + full_overlap_summary_df = _aggregate_from_details(full_overlap_df).assign( + ontology="full", + evaluation_scope="overlap_with_reduced", + ) + reduced_overlap_df = reduced_details_df.loc[ + reduced_details_df["prediction_label"].astype(str).isin(overlap_labels) + ].copy() + reduced_overlap_summary_df = _aggregate_from_details(reduced_overlap_df).assign( + ontology="reduced", + evaluation_scope="overlap_with_reduced", + ) + + comparison_df = pd.concat( + [reduced_eval_df, full_eval_df, reduced_overlap_summary_df, full_overlap_summary_df], + ignore_index=True, + sort=False, + ) + comparison_df.insert(0, "dataset_name", dataset_name) + + os.makedirs(op.dirname(op.abspath(comparison_fn)), exist_ok=True) + comparison_df.to_csv(comparison_fn, index=False) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) + + +if __name__ == "__main__": + _main() diff --git a/jobs/per_term_eval.py b/jobs/per_term_eval.py new file mode 100644 index 0000000..e2380eb --- /dev/null +++ b/jobs/per_term_eval.py @@ -0,0 +1,260 @@ +"""Compute per-term decoding performance with a permutation null baseline.""" + +import argparse +import json +import os +import os.path as op + +import numpy as np +import pandas as pd +from tqdm.auto import tqdm + + +GROUP_COLUMNS = [ + "dataset", + "source", + "category", + "sub_category", + "section", + "model_id", + "model_name", + "vocabulary_label", + "backend", + "level", + "k", +] + + +def _parse_json_list(value): + if isinstance(value, list): + return value + if pd.isna(value): + return [] + return list(json.loads(value)) + + +def _term_rank(term, predictions): + for rank, pred in enumerate(predictions, start=1): + if pred == term: + return rank + return np.inf + + +def _make_term_observations(details_df): + rows = [] + for row in details_df.itertuples(index=False): + true_labels = _parse_json_list(row.true_labels_json) + predictions = _parse_json_list(row.top_predictions_json) + for term in true_labels: + rank = _term_rank(term, predictions) + rows.append( + { + "dataset": row.dataset, + "source": row.source, + "category": row.category, + "sub_category": row.sub_category, + "section": row.section, + "model_id": row.model_id, + "model_name": row.model_name, + "vocabulary_label": row.vocabulary_label, + "backend": row.backend, + "level": row.level, + "k": row.k, + "term": term, + "rank": rank, + "hit_at_k": float(rank <= row.k), + } + ) + return pd.DataFrame(rows) + + +def _aggregate_term_rows(term_df): + grouped = term_df.groupby(GROUP_COLUMNS + ["term"], dropna=False, sort=False) + return ( + grouped.agg( + n_images=("term", "size"), + mean_hit_at_k=("hit_at_k", "mean"), + mean_rank=("rank", lambda x: np.nanmean(np.where(np.isfinite(x), x, np.nan))), + median_rank=("rank", lambda x: np.nanmedian(np.where(np.isfinite(x), x, np.nan))), + mean_reciprocal_rank=("rank", lambda x: np.mean(np.where(np.isfinite(x), 1.0 / x, 0.0))), + ) + .reset_index() + ) + + +def _permutation_null(details_df, actual_df, n_permutations, random_seed): + rng = np.random.default_rng(random_seed) + actual_keyed = actual_df.set_index(GROUP_COLUMNS + ["term"]) + accum = {} + + grouped = details_df.groupby(GROUP_COLUMNS, dropna=False, sort=False) + for group_key, group_df in tqdm(grouped, total=grouped.ngroups, desc="permutation groups"): + group_rows = group_df.copy() + prediction_lists = [ + _parse_json_list(value) for value in group_rows["top_predictions_json"].tolist() + ] + true_lists = [_parse_json_list(value) for value in group_rows["true_labels_json"].tolist()] + k_value = int(group_rows["k"].iloc[0]) + term_counts = {} + for true_terms in true_lists: + for term in true_terms: + term_counts[term] = term_counts.get(term, 0) + 1 + + observed_hits = { + term: actual_keyed.loc[group_key + (term,), "mean_hit_at_k"] for term in term_counts + } + + sum_means = {term: 0.0 for term in term_counts} + sum_sq_means = {term: 0.0 for term in term_counts} + ge_counts = {term: 0 for term in term_counts} + + for _ in range(n_permutations): + permuted_indices = rng.permutation(len(prediction_lists)) + hit_sums = {term: 0.0 for term in term_counts} + + for row_idx, true_terms in enumerate(true_lists): + permuted_predictions = prediction_lists[permuted_indices[row_idx]] + topk_predictions = set(permuted_predictions[:k_value]) + for term in true_terms: + hit_sums[term] += float(term in topk_predictions) + + for term, count in term_counts.items(): + perm_mean = hit_sums[term] / count + sum_means[term] += perm_mean + sum_sq_means[term] += perm_mean**2 + ge_counts[term] += int(perm_mean >= observed_hits[term]) + + for term, count in term_counts.items(): + key = group_key + (term,) + null_mean = sum_means[term] / n_permutations + variance = max(sum_sq_means[term] / n_permutations - null_mean**2, 0.0) + accum[key] = { + "null_mean_hit_at_k": null_mean, + "null_std_hit_at_k": float(np.sqrt(variance)), + "empirical_p_value": (ge_counts[term] + 1) / (n_permutations + 1), + "n_null_permutations": n_permutations, + } + + return pd.DataFrame( + [ + dict(zip(GROUP_COLUMNS + ["term"], key), **value) + for key, value in accum.items() + ] + ) + + +def _get_parser(): + parser = argparse.ArgumentParser( + description="Aggregate detailed decoding evaluations into per-term scores with a null baseline." + ) + parser.add_argument( + "--details_fns", + dest="details_fns", + nargs="+", + required=True, + help="One or more detailed evaluation CSVs emitted by jobs/decoding_eval.py.", + ) + parser.add_argument( + "--levels", + dest="levels", + nargs="+", + default=["task", "concept", "domain"], + help="Prediction levels to analyze.", + ) + parser.add_argument( + "--backends", + dest="backends", + nargs="+", + default=None, + help="Optional backend filter, e.g. brainclip or gclda.", + ) + parser.add_argument( + "--n_permutations", + dest="n_permutations", + type=int, + default=1000, + help="Number of within-group permutations for the null baseline.", + ) + parser.add_argument( + "--random_seed", + dest="random_seed", + type=int, + default=0, + help="Random seed for permutation generation.", + ) + parser.add_argument( + "--alpha", + dest="alpha", + type=float, + default=0.05, + help="Significance threshold used for the above-chance summary flag.", + ) + parser.add_argument( + "--output_fn", + dest="output_fn", + required=True, + help="Path to the per-term output CSV.", + ) + parser.add_argument( + "--summary_output_fn", + dest="summary_output_fn", + default=None, + help="Optional path to a group-level summary CSV.", + ) + return parser + + +def main( + details_fns, + levels=None, + backends=None, + n_permutations=1000, + random_seed=0, + alpha=0.05, + output_fn=None, + summary_output_fn=None, +): + levels = ["task", "concept", "domain"] if levels is None else levels + details_df = pd.concat([pd.read_csv(path) for path in details_fns], ignore_index=True) + details_df = details_df.loc[details_df["level"].isin(levels)].reset_index(drop=True) + if backends is not None: + details_df = details_df.loc[details_df["backend"].isin(backends)].reset_index(drop=True) + + term_df = _make_term_observations(details_df) + actual_df = _aggregate_term_rows(term_df) + null_df = _permutation_null(details_df, actual_df, n_permutations, random_seed) + result_df = actual_df.merge(null_df, on=GROUP_COLUMNS + ["term"], how="left") + result_df["normalized_hit_at_k"] = result_df["mean_hit_at_k"] / result_df["null_mean_hit_at_k"].replace( + 0, + np.nan, + ) + result_df["above_chance"] = ( + (result_df["mean_hit_at_k"] > result_df["null_mean_hit_at_k"]) + & (result_df["empirical_p_value"] <= alpha) + ) + + os.makedirs(op.dirname(op.abspath(output_fn)), exist_ok=True) + result_df.to_csv(output_fn, index=False) + + if summary_output_fn is not None: + summary_df = ( + result_df.groupby(GROUP_COLUMNS, dropna=False, sort=False) + .agg( + n_terms=("term", "size"), + n_terms_above_chance=("above_chance", "sum"), + mean_term_hit_at_k=("mean_hit_at_k", "mean"), + mean_term_null_hit_at_k=("null_mean_hit_at_k", "mean"), + ) + .reset_index() + ) + os.makedirs(op.dirname(op.abspath(summary_output_fn)), exist_ok=True) + summary_df.to_csv(summary_output_fn, index=False) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) + + +if __name__ == "__main__": + _main() diff --git a/jobs/per_term_factors.py b/jobs/per_term_factors.py new file mode 100644 index 0000000..f745448 --- /dev/null +++ b/jobs/per_term_factors.py @@ -0,0 +1,258 @@ +"""Analyze which task-level features are associated with per-term decoding performance.""" + +import argparse +import os +import os.path as op + +import nibabel as nib +import numpy as np +import pandas as pd +from scipy.stats import spearmanr +from tqdm.auto import tqdm + +from jobs.utils import build_cognitiveatlas, resolve_project_paths + + +def _get_parser(): + parser = argparse.ArgumentParser( + description="Correlate per-term decoding performance with simple task-level features." + ) + parser.add_argument( + "--project_dir", + dest="project_dir", + default=None, + help="Path to the repository root.", + ) + parser.add_argument( + "--data_dir", + dest="data_dir", + default=None, + help="Optional explicit data directory.", + ) + parser.add_argument( + "--results_dir", + dest="results_dir", + default=None, + help="Optional explicit results directory.", + ) + parser.add_argument( + "--per_term_fn", + dest="per_term_fn", + default=None, + help="Per-term CSV from jobs/per_term_eval.py.", + ) + parser.add_argument( + "--ibc_mapping_fn", + dest="ibc_mapping_fn", + default=None, + help="IBC reduced mapping CSV.", + ) + parser.add_argument( + "--cnp_mapping_fn", + dest="cnp_mapping_fn", + default=None, + help="CNP reduced mapping CSV.", + ) + parser.add_argument( + "--output_table_fn", + dest="output_table_fn", + required=True, + help="Path to the per-term feature table CSV.", + ) + parser.add_argument( + "--output_summary_fn", + dest="output_summary_fn", + required=True, + help="Path to the correlation/regression summary CSV.", + ) + return parser + + +def _safe_zscore(values): + values = np.asarray(values, dtype=float) + std = np.nanstd(values) + if not np.isfinite(std) or std == 0: + return np.zeros_like(values) + return (values - np.nanmean(values)) / std + + +def _term_specificity(image_paths): + specificities = [] + for image_path in image_paths: + data = np.asanyarray(nib.load(image_path).dataobj, dtype=np.float32) + data = np.abs(data[np.isfinite(data)]) + if data.size == 0: + continue + l1 = float(data.sum()) + l2_sq = float((data**2).sum()) + if l1 == 0 or l2_sq == 0: + continue + effective_voxels = (l1**2) / l2_sq + specificity = 1.0 - (effective_voxels / data.size) + specificities.append(specificity) + if not specificities: + return np.nan + return float(np.mean(specificities)) + + +def _load_task_feature_frame(data_dir, ibc_mapping_fn, cnp_mapping_fn): + cognitiveatlas = build_cognitiveatlas(data_dir, reduced=True) + task_df = pd.DataFrame( + { + "term": cognitiveatlas.task_names, + "definition_length_chars": [len(text) for text in cognitiveatlas.task_definitions], + "definition_length_words": [len(text.split()) for text in cognitiveatlas.task_definitions], + "n_linked_concepts": [len(idxs) for idxs in cognitiveatlas.task_to_concept_idxs], + "n_linked_domains": [len(idxs) for idxs in cognitiveatlas.task_to_process_idxs], + } + ) + + counts = np.load(op.join(data_dir, "vocabulary", "vocabulary-cogatlasred_task-names_section-body_counts.npy")) + if counts.ndim == 2: + counts = counts.sum(axis=1) + emb = np.load( + op.join( + data_dir, + "vocabulary", + "vocabulary-cogatlasred_task-combined_embedding-BrainGPT-7B-v0.2.npy", + ) + ) + task_df["training_article_count"] = counts + task_df["embedding_norm"] = np.linalg.norm(emb, axis=1) + + mapping_frames = [] + for dataset, mapping_fn in [ + ("ibc", ibc_mapping_fn), + ("cnp", cnp_mapping_fn), + ]: + if mapping_fn is None: + continue + mapping_df = pd.read_csv(mapping_fn) + mapping_df = mapping_df.loc[:, ["task_name", "local_path"]].copy() + mapping_df["dataset"] = dataset + mapping_frames.append(mapping_df.rename(columns={"task_name": "term"})) + mappings_df = pd.concat(mapping_frames, ignore_index=True) + + specificity_rows = [] + grouped = mappings_df.groupby("term", dropna=False, sort=True) + for term, group_df in tqdm(grouped, total=grouped.ngroups, desc="term specificity"): + specificity_rows.append( + { + "term": term, + "mean_map_specificity": _term_specificity(group_df["local_path"].tolist()), + "n_eval_maps": int(len(group_df)), + "n_eval_datasets": int(group_df["dataset"].nunique()), + } + ) + + specificity_df = pd.DataFrame(specificity_rows) + return task_df.merge(specificity_df, on="term", how="left") + + +def _fit_standardized_regression(df, predictors, target): + work_df = df.loc[:, predictors + [target]].dropna().copy() + if len(work_df) < 3: + return [] + x = np.column_stack([_safe_zscore(work_df[col]) for col in predictors] + [np.ones(len(work_df))]) + y = _safe_zscore(work_df[target].to_numpy()) + coef, _, _, _ = np.linalg.lstsq(x, y, rcond=None) + y_hat = x @ coef + ss_res = float(((y - y_hat) ** 2).sum()) + ss_tot = float(((y - y.mean()) ** 2).sum()) + r2 = np.nan if ss_tot == 0 else 1.0 - ss_res / ss_tot + rows = [] + for predictor, beta in zip(predictors, coef[:-1]): + rows.append( + { + "analysis": "standardized_regression", + "target": target, + "feature": predictor, + "value": float(beta), + "n_terms": int(len(work_df)), + "r2": r2, + } + ) + return rows + + +def main( + project_dir=None, + data_dir=None, + results_dir=None, + per_term_fn=None, + ibc_mapping_fn=None, + cnp_mapping_fn=None, + output_table_fn=None, + output_summary_fn=None, +): + _, data_dir, results_dir = resolve_project_paths(project_dir, data_dir, results_dir) + per_term_fn = ( + op.join(results_dir, "per_term_cross_dataset_reduced_full.csv") + if per_term_fn is None + else op.abspath(per_term_fn) + ) + ibc_mapping_fn = ( + op.join(data_dir, "ibc", "mapping_reduced.csv") + if ibc_mapping_fn is None + else op.abspath(ibc_mapping_fn) + ) + cnp_mapping_fn = ( + op.join(data_dir, "cnp", "mapping_reduced.csv") + if cnp_mapping_fn is None + else op.abspath(cnp_mapping_fn) + ) + + per_term_df = pd.read_csv(per_term_fn) + per_term_df = per_term_df.loc[ + (per_term_df["backend"] == "brainclip") + & (per_term_df["level"] == "task") + & (per_term_df["sub_category"] == "combined") + & (per_term_df["section"] == "body") + ].copy() + + feature_df = _load_task_feature_frame(data_dir, ibc_mapping_fn, cnp_mapping_fn) + result_df = per_term_df.merge(feature_df, on="term", how="left") + os.makedirs(op.dirname(op.abspath(output_table_fn)), exist_ok=True) + result_df.to_csv(output_table_fn, index=False) + + feature_columns = [ + "training_article_count", + "definition_length_words", + "embedding_norm", + "n_linked_concepts", + "n_linked_domains", + "mean_map_specificity", + "n_eval_maps", + ] + summary_rows = [] + for target in ["mean_hit_at_k", "mean_reciprocal_rank"]: + for feature in feature_columns: + work_df = result_df.loc[:, [feature, target]].dropna() + if len(work_df) < 3: + continue + rho, p_value = spearmanr(work_df[feature], work_df[target]) + summary_rows.append( + { + "analysis": "spearman", + "target": target, + "feature": feature, + "value": float(rho), + "p_value": float(p_value), + "n_terms": int(len(work_df)), + "r2": np.nan, + } + ) + summary_rows.extend(_fit_standardized_regression(result_df, feature_columns, target)) + + summary_df = pd.DataFrame(summary_rows) + os.makedirs(op.dirname(op.abspath(output_summary_fn)), exist_ok=True) + summary_df.to_csv(output_summary_fn, index=False) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) + + +if __name__ == "__main__": + _main() diff --git a/jobs/prepare_neurovault_data.py b/jobs/prepare_neurovault_data.py new file mode 100644 index 0000000..ab3c1a4 --- /dev/null +++ b/jobs/prepare_neurovault_data.py @@ -0,0 +1,274 @@ +"""Download and normalize NeuroVault datasets used for cross-dataset decoding.""" + +import argparse +import json +import os +import os.path as op +import re +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +import pandas as pd +import requests +from tqdm.auto import tqdm + +IBC_COLLECTIONS = [2138, 6618] +CNP_COLLECTIONS = [2606] +DEFAULT_TIMEOUT = 120 + + +def _slugify(value): + value = re.sub(r"[^0-9A-Za-z]+", "-", str(value).strip()) + value = re.sub(r"-{2,}", "-", value).strip("-") + return value or "image" + + +def _request_json(url, timeout=DEFAULT_TIMEOUT): + response = requests.get(url, timeout=timeout) + response.raise_for_status() + return response.json() + + +def _iter_collection_images(collection_id, timeout=DEFAULT_TIMEOUT): + url = f"https://neurovault.org/api/collections/{collection_id}/images/" + while url: + payload = _request_json(url, timeout=timeout) + for row in payload["results"]: + yield row + url = payload.get("next") + + +def _fetch_collection_metadata(collection_id, timeout=DEFAULT_TIMEOUT): + return _request_json(f"https://neurovault.org/api/collections/{collection_id}/", timeout=timeout) + + +def _infer_dataset_task_family(dataset_name, row): + if dataset_name == "cnp": + name = (row.get("name") or "").strip() + return name.split()[0].upper() if name else "UNKNOWN" + if dataset_name == "ibc": + task = row.get("task") or "" + return task.strip() or "UNKNOWN" + return "UNKNOWN" + + +def _normalize_image_record(dataset_name, collection_id, row): + image_id = row["id"] + source_name = row.get("name") or f"image-{image_id}" + prediction_label = f"nv{image_id}-{_slugify(source_name)}" + return { + "dataset": dataset_name, + "collection_id": collection_id, + "image_id": image_id, + "prediction_label": prediction_label, + "source_name": source_name, + "file_url": row.get("file"), + "map_type": row.get("map_type"), + "modality": row.get("modality"), + "task_family": _infer_dataset_task_family(dataset_name, row), + "task_name": row.get("cognitive_paradigm_cogatlas"), + "task_cogatlas_id": row.get("cognitive_paradigm_cogatlas_id"), + "task_code": row.get("task"), + "contrast_definition": row.get("contrast_definition"), + } + + +def _download_file(url, destination, timeout=DEFAULT_TIMEOUT, overwrite=False): + destination = Path(destination) + destination.parent.mkdir(parents=True, exist_ok=True) + if destination.exists() and not overwrite: + return destination + + tmp_destination = destination.with_suffix(destination.suffix + ".part") + with requests.get(url, stream=True, timeout=timeout) as response: + response.raise_for_status() + with tmp_destination.open("wb") as file_obj: + for chunk in response.iter_content(chunk_size=1024 * 1024): + if chunk: + file_obj.write(chunk) + tmp_destination.replace(destination) + return destination + + +def _download_images(records, output_dir, num_workers=8, timeout=DEFAULT_TIMEOUT, overwrite=False): + output_dir = Path(output_dir) + outputs = [] + with ThreadPoolExecutor(max_workers=max(1, num_workers)) as executor: + future_to_record = {} + for record in records: + if not record.get("file_url"): + continue + destination = output_dir / f"{record['prediction_label']}.nii.gz" + future = executor.submit( + _download_file, + record["file_url"], + destination, + timeout, + overwrite, + ) + future_to_record[future] = (record, destination) + + for future in tqdm(as_completed(future_to_record), total=len(future_to_record), desc="download images"): + record, destination = future_to_record[future] + future.result() + outputs.append( + { + **record, + "local_path": str(destination), + "filename": destination.name, + } + ) + + outputs_df = pd.DataFrame(outputs).sort_values(by=["collection_id", "image_id"]).reset_index(drop=True) + return outputs_df + + +def _write_collection_metadata(dataset_name, collection_id, collection_meta, records, destination_root): + destination_root = Path(destination_root) + destination_root.mkdir(parents=True, exist_ok=True) + with (destination_root / f"collection-{collection_id}.json").open("w") as file_obj: + json.dump(collection_meta, file_obj, indent=2) + pd.DataFrame(records).to_csv(destination_root / f"collection-{collection_id}_manifest.csv", index=False) + + +def _get_parser(): + parser = argparse.ArgumentParser(description="Download and normalize NeuroVault collections.") + parser.add_argument( + "--project_dir", + dest="project_dir", + default=op.abspath(op.join(op.dirname(__file__), "..")), + help="Path to the repository root.", + ) + parser.add_argument( + "--datasets", + dest="datasets", + nargs="+", + default=["cnp", "ibc"], + choices=["cnp", "ibc"], + help="Datasets to prepare.", + ) + parser.add_argument( + "--ibc_collections", + dest="ibc_collections", + nargs="+", + type=int, + default=[2138, 6618], + help="IBC NeuroVault collection IDs to index.", + ) + parser.add_argument( + "--cnp_collections", + dest="cnp_collections", + nargs="+", + type=int, + default=[2606], + help="CNP NeuroVault collection IDs to index.", + ) + parser.add_argument( + "--download_ibc_images", + dest="download_ibc_images", + action="store_true", + help="Download IBC NIfTI images in addition to manifests.", + ) + parser.add_argument( + "--download_cnp_images", + dest="download_cnp_images", + action="store_true", + help="Download CNP NIfTI images in addition to manifests.", + ) + parser.add_argument( + "--ibc_download_collections", + dest="ibc_download_collections", + nargs="+", + type=int, + default=[2138], + help="IBC collection IDs whose images should be downloaded when --download_ibc_images is set.", + ) + parser.add_argument( + "--num_workers", + dest="num_workers", + type=int, + default=8, + help="Concurrent download worker count.", + ) + parser.add_argument( + "--timeout", + dest="timeout", + type=int, + default=DEFAULT_TIMEOUT, + help="Per-request timeout in seconds.", + ) + parser.add_argument( + "--overwrite", + dest="overwrite", + action="store_true", + help="Overwrite existing downloaded files.", + ) + return parser + + +def main( + project_dir, + datasets, + ibc_collections, + cnp_collections, + download_ibc_images=False, + download_cnp_images=False, + ibc_download_collections=None, + num_workers=8, + timeout=DEFAULT_TIMEOUT, + overwrite=False, +): + project_dir = op.abspath(project_dir) + ibc_download_collections = [2138] if ibc_download_collections is None else ibc_download_collections + + for dataset_name in datasets: + collection_ids = ibc_collections if dataset_name == "ibc" else cnp_collections + metadata_dir = Path(project_dir) / "data" / dataset_name / "metadata" + image_dir = Path(project_dir) / "data" / dataset_name + metadata_dir.mkdir(parents=True, exist_ok=True) + image_dir.mkdir(parents=True, exist_ok=True) + + all_records = [] + for collection_id in tqdm(collection_ids, desc=f"{dataset_name} collections"): + collection_meta = _fetch_collection_metadata(collection_id, timeout=timeout) + records = [ + _normalize_image_record(dataset_name, collection_id, row) + for row in _iter_collection_images(collection_id, timeout=timeout) + ] + _write_collection_metadata(dataset_name, collection_id, collection_meta, records, metadata_dir) + all_records.extend(records) + + manifest_df = pd.DataFrame(all_records).sort_values(by=["collection_id", "image_id"]).reset_index(drop=True) + manifest_fn = image_dir / "neurovault_manifest.csv" + manifest_df.to_csv(manifest_fn, index=False) + + should_download = (dataset_name == "ibc" and download_ibc_images) or ( + dataset_name == "cnp" and download_cnp_images + ) + if not should_download: + continue + + if dataset_name == "ibc": + download_records = [ + record for record in all_records if record["collection_id"] in set(ibc_download_collections) + ] + else: + download_records = list(all_records) + + downloaded_df = _download_images( + download_records, + output_dir=image_dir, + num_workers=num_workers, + timeout=timeout, + overwrite=overwrite, + ) + downloaded_df.to_csv(image_dir / "mapping.csv", index=False) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) + + +if __name__ == "__main__": + _main() diff --git a/jobs/reproduce_hcp_benchmark.py b/jobs/reproduce_hcp_benchmark.py new file mode 100644 index 0000000..a00c996 --- /dev/null +++ b/jobs/reproduce_hcp_benchmark.py @@ -0,0 +1,642 @@ +import argparse +import multiprocessing as mp +import itertools +import json +import os +import os.path as op +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path + +import pandas as pd +import requests +from tqdm.auto import tqdm +from jobs.utils import ( + DEFAULT_MODEL_IDS, + DEFAULT_SECTIONS, + build_cognitiveatlas, + load_decoding_resources, +) + +HCP_COLLECTION_ID = 457 + +HCP_REPRESENTATIVE_MAPS = [ + { + "domain_key": "emotion", + "task_code": "EMOTION", + "image_id": 3128, + "contrast_label": "Faces vs Shapes", + "filename": "tfMRI_EMOTION_FACES-SHAPES_zstat1.nii.gz", + "task_name": "emotion processing fMRI task paradigm", + }, + { + "domain_key": "gambling", + "task_code": "GAMBLING", + "image_id": 3137, + "contrast_label": "Reward", + "filename": "tfMRI_GAMBLING_REWARD_zstat1.nii.gz", + "task_name": "gambling fMRI task paradigm", + }, + { + "domain_key": "language", + "task_code": "LANGUAGE", + "image_id": 3142, + "contrast_label": "Story vs Math", + "filename": "tfMRI_LANGUAGE_STORY-MATH_zstat1.nii.gz", + "task_name": "language processing fMRI task paradigm", + }, + { + "domain_key": "motor", + "task_code": "MOTOR", + "image_id": 3152, + "contrast_label": "Average", + "filename": "tfMRI_MOTOR_AVG_zstat1.nii.gz", + "task_name": "motor fMRI task paradigm", + }, + { + "domain_key": "relational", + "task_code": "RELATIONAL", + "image_id": 8820, + "contrast_label": "Relational vs Match", + "filename": "tfMRI_RELATIONAL_REL-MATCH.nii_tstat1.nii.gz", + "task_name": "relational processing fMRI task paradigm", + }, + { + "domain_key": "social", + "task_code": "SOCIAL", + "image_id": 3180, + "contrast_label": "TOM vs Random", + "filename": "tfMRI_SOCIAL_TOM-RANDOM_zstat1.nii.gz", + "task_name": "social cognition (theory of mind) fMRI task paradigm", + }, + { + "domain_key": "working_memory", + "task_code": "WM", + "image_id": 3190, + "contrast_label": "2-Back vs 0-Back", + "filename": "tfMRI_WM_2BK-0BK_zstat1.nii.gz", + "task_name": "working memory fMRI task paradigm", + }, +] + + +def _download_file(url, destination, timeout=120): + destination = Path(destination) + destination.parent.mkdir(parents=True, exist_ok=True) + if destination.exists(): + return destination + + with requests.get(url, stream=True, timeout=timeout) as response: + response.raise_for_status() + with destination.open("wb") as file_obj: + for chunk in response.iter_content(chunk_size=1024 * 1024): + if chunk: + file_obj.write(chunk) + return destination + + +def _get_parser(): + parser = argparse.ArgumentParser(description="Reproduce the HCP benchmark with published NiCLIP assets") + parser.add_argument( + "--workdir", + dest="workdir", + default="/tmp/brain-decoder-hcp", + help="Root directory for downloaded assets and benchmark outputs.", + ) + parser.add_argument( + "--sections", + dest="sections", + nargs="+", + default=list(DEFAULT_SECTIONS), + help="Text sections to evaluate.", + ) + parser.add_argument( + "--model_ids", + dest="model_ids", + nargs="+", + default=list(DEFAULT_MODEL_IDS), + help="Embedding model identifiers to evaluate.", + ) + parser.add_argument( + "--sources", + dest="sources", + nargs="+", + default=["cogatlasred", "cogatlas"], + help="Vocabulary sources to evaluate.", + ) + parser.add_argument( + "--sub_categories", + dest="sub_categories", + nargs="+", + default=["combined", "names"], + help="Vocabulary embedding variants to evaluate.", + ) + parser.add_argument( + "--topk", + dest="topk", + type=int, + default=20, + help="Number of predictions to keep per image.", + ) + parser.add_argument( + "--device", + dest="device", + default=None, + help="Optional device override.", + ) + parser.add_argument( + "--devices", + dest="devices", + nargs="+", + default=None, + help="Optional list of devices to use in parallel (e.g., cuda:0 cuda:1 cuda:2 cuda:3).", + ) + parser.add_argument( + "--num_workers", + dest="num_workers", + type=int, + default=None, + help="Number of parallel config workers. Defaults to the number of devices in --devices.", + ) + parser.add_argument( + "--max_images", + dest="max_images", + type=int, + default=None, + help="Optional cap on the number of benchmark maps. Useful for runtime estimation.", + ) + parser.add_argument( + "--overwrite", + dest="overwrite", + action="store_true", + help="Overwrite existing prediction CSVs instead of skipping them.", + ) + parser.add_argument( + "--skip_eval", + dest="skip_eval", + action="store_true", + help="Skip aggregate evaluation. Useful when running a subset of images for timing.", + ) + parser.add_argument( + "--fetch_timeout", + dest="fetch_timeout", + type=int, + default=300, + help="Per-request timeout in seconds for OSF and NeuroVault downloads.", + ) + return parser + + +def _prepare_assets(workdir, fetch_timeout=300): + from braindec import fetcher + + fetcher.download_asset("brain_mask_mni152_2mm", destination_root=workdir, timeout=fetch_timeout) + fetcher.download_asset("cognitive_atlas", destination_root=workdir, timeout=fetch_timeout) + fetcher.download_osf_folder("data/vocabulary", destination_root=workdir, timeout=fetch_timeout) + + +def _prepare_minimal_assets(workdir, sections, model_ids, sources, sub_categories, fetch_timeout=300): + from braindec import fetcher + from jobs.utils import get_model_name + + fetcher.download_asset("brain_mask_mni152_2mm", destination_root=workdir, timeout=fetch_timeout) + fetcher.download_asset("cognitive_atlas", destination_root=workdir, timeout=fetch_timeout) + + required_pubmed = set() + required_baseline = set() + required_vocabulary = set() + + for section in sections: + for model_id in model_ids: + model_name = get_model_name(model_id) + required_pubmed.add( + f"model-clip_section-{section}_embedding-{model_name}_best.pth" + ) + + for source, sub_category, section, model_id in itertools.product( + sources, + sub_categories, + sections, + model_ids, + ): + model_name = get_model_name(model_id) + vocabulary_label = f"vocabulary-{source}_task-{sub_category}_embedding-{model_name}" + required_vocabulary.add(f"vocabulary-{source}_task.txt") + required_vocabulary.add(f"{vocabulary_label}.npy") + required_vocabulary.add(f"{vocabulary_label}_section-{section}_prior.npy") + + if "names" in sub_categories: + for source, section, model_id in itertools.product(sources, sections, model_ids): + model_name = get_model_name(model_id) + baseline_label = f"{source}-task_embedding-{model_name}_section-{section}" + required_baseline.add(f"model-gclda_{baseline_label}.pkl") + required_baseline.add(f"model-neurosynth_{baseline_label}.pkl") + + def _download_from_listing(folder_path, required_names, allow_csv_prior_fallback=False): + listing = fetcher.list_remote_assets(remote_path=folder_path, timeout=fetch_timeout) + items_by_name = {item["attributes"]["name"]: item for item in listing} + + for name in sorted(required_names): + fallback_name = None + if name.endswith("_prior.npy") and allow_csv_prior_fallback and name not in items_by_name: + fallback_name = name[:-4] + ".csv" + + selected_name = name if name in items_by_name else fallback_name + if selected_name is None: + raise FileNotFoundError(f"{folder_path}/{name} was not found in published OSF assets.") + + print(f"Downloading {folder_path}/{selected_name}", flush=True) + fetcher.download_osf_file( + items_by_name[selected_name]["id"], + destination_root=workdir, + timeout=fetch_timeout, + ) + + _download_from_listing("results/pubmed", required_pubmed) + _download_from_listing("results/baseline", required_baseline) + _download_from_listing("data/vocabulary", required_vocabulary, allow_csv_prior_fallback=True) + + +def _prepare_hcp_inputs(workdir, fetch_timeout=300): + image_dir = Path(workdir) / "data" / "hcp" / "neurovault" + image_dir.mkdir(parents=True, exist_ok=True) + rows = [] + for item in HCP_REPRESENTATIVE_MAPS: + url = f"https://neurovault.org/media/images/{HCP_COLLECTION_ID}/{item['filename']}" + local_path = _download_file(url, image_dir / item["filename"], timeout=fetch_timeout) + rows.append({**item, "image_path": str(local_path)}) + return pd.DataFrame(rows) + + +def _prepare_nilearn_assets(workdir): + from nilearn import datasets + + difumo_kwargs = { + "dimension": 512, + "resolution_mm": 2, + "data_dir": op.join(workdir, "data", "nilearn"), + } + try: + datasets.fetch_atlas_difumo(legacy_format=False, **difumo_kwargs) + except TypeError: + datasets.fetch_atlas_difumo(**difumo_kwargs) + + +def _build_ground_truth(mapping_df, data_dir): + concept_to_process_fn = op.join(data_dir, "cognitive_atlas", "concept_to_process.json") + cognitiveatlas = build_cognitiveatlas(data_dir, reduced=True, concept_to_process_fn=concept_to_process_fn) + + records = {} + doc_rows = [] + for row in mapping_df.to_dict(orient="records"): + task_idx = cognitiveatlas.get_task_idx_from_names(row["task_name"]) + concept_names = cognitiveatlas.get_concept_names_from_idx(cognitiveatlas.task_to_concept_idxs[task_idx]) + process_names = cognitiveatlas.get_process_names_from_idx(cognitiveatlas.task_to_process_idxs[task_idx]) + records[row["domain_key"]] = { + "task": [row["task_name"]], + "concept": concept_names.tolist(), + "domain": process_names.tolist(), + } + doc_rows.append( + { + **row, + "concepts": "; ".join(concept_names.tolist()), + "domains": "; ".join(process_names.tolist()), + } + ) + + return records, pd.DataFrame(doc_rows) + + +def _iter_prediction_jobs(sources, sections, model_ids, sub_categories): + for source, section, model_id, sub_category in itertools.product( + sources, + sections, + model_ids, + sub_categories, + ): + yield { + "source": source, + "section": section, + "model_id": model_id, + "sub_category": sub_category, + } + + +def _resolve_devices(device=None, devices=None): + if devices: + return devices + + if device is not None: + return [device] + + import torch + + if torch.cuda.is_available(): + return [f"cuda:{device_idx}" for device_idx in range(torch.cuda.device_count())] + + return ["cpu"] + + +def _run_prediction_job(job, mapping_records, workdir, topk, device, overwrite=False, show_progress=False): + import nibabel as nib + from nilearn.image import resample_to_img + from nimare.annotate.gclda import GCLDAModel + from nimare.decode.continuous import CorrelationDecoder, gclda_decode_map + + from braindec.embedding import ImageEmbedding + from braindec.model import build_model + from braindec.predict import image_to_labels_hierarchical + from braindec.utils import _get_device, images_have_same_fov + + data_dir = op.join(workdir, "data") + results_dir = op.join(workdir, "results") + mask_img = nib.load(op.join(data_dir, "MNI152_2x2x2_brainmask.nii.gz")) + device = _get_device() if device is None else device + + output_dir = op.join(results_dir, "predictions_hcp_nv") + os.makedirs(output_dir, exist_ok=True) + + def _resample_for_reference(image, reference_img): + if reference_img is None or images_have_same_fov(image, reference_img): + return image + return resample_to_img(image, reference_img) + + source = job["source"] + section = job["section"] + model_id = job["model_id"] + sub_category = job["sub_category"] + reduced = source == "cogatlasred" + concept_to_process_fn = op.join(data_dir, "cognitive_atlas", "concept_to_process.json") + cognitiveatlas = build_cognitiveatlas(data_dir, reduced=reduced, concept_to_process_fn=concept_to_process_fn) + resources = load_decoding_resources( + op.join(workdir, "results"), + op.join(data_dir, "vocabulary"), + source, + "task", + sub_category, + model_id, + section, + ) + model = build_model(resources["model_path"], device=device) + image_emb_gene = ImageEmbedding( + standardize=False, + nilearn_dir=op.join(data_dir, "nilearn"), + space="MNI152", + ) + gclda_model = None + ns_decoder = None + ns_masker = None + + if sub_category == "names": + baseline_label = f"{source}-task_embedding-{resources['model_name']}_section-{section}" + gclda_model = GCLDAModel.load(op.join(results_dir, "baseline", f"model-gclda_{baseline_label}.pkl")) + ns_decoder = CorrelationDecoder.load( + op.join(results_dir, "baseline", f"model-neurosynth_{baseline_label}.pkl") + ) + if hasattr(ns_decoder, "results_"): + ns_masker = getattr(ns_decoder.results_, "masker", None) + if ns_masker is None and hasattr(ns_decoder, "masker"): + ns_masker = ns_decoder.masker + if ns_masker is not None and not hasattr(ns_masker, "clean_args_"): + clean_kwargs = getattr(ns_masker, "clean_kwargs", None) + ns_masker.clean_args_ = {} if clean_kwargs is None else clean_kwargs + + iterator = mapping_records + if show_progress: + iterator = tqdm( + mapping_records, + total=len(mapping_records), + desc=( + f"images {source}/{section}/{resources['model_name']}/{sub_category}" + ), + leave=False, + ) + + started_at = time.perf_counter() + brainclip_outputs = 0 + gclda_outputs = 0 + ns_outputs = 0 + + for row in iterator: + img = nib.load(row["image_path"]) + if not images_have_same_fov(img, mask_img): + img = resample_to_img(img, mask_img) + + file_base_name = f"{row['task_code']}_{resources['vocabulary_label']}_section-{section}" + task_out_fn = op.join(output_dir, f"{file_base_name}_pred-task_brainclip.csv") + concept_out_fn = op.join(output_dir, f"{file_base_name}_pred-concept_brainclip.csv") + process_out_fn = op.join(output_dir, f"{file_base_name}_pred-process_brainclip.csv") + + if overwrite or not op.exists(task_out_fn): + task_prob_df, concept_prob_df, process_prob_df = image_to_labels_hierarchical( + img, + resources["model_path"], + resources["vocabulary"], + resources["vocabulary_emb"], + resources["vocabulary_prior"], + cognitiveatlas, + topk=topk, + logit_scale=20.0, + model=model, + image_emb_gene=image_emb_gene, + data_dir=data_dir, + device=device, + ) + task_prob_df.to_csv(task_out_fn, index=False) + concept_prob_df.to_csv(concept_out_fn, index=False) + process_prob_df.to_csv(process_out_fn, index=False) + brainclip_outputs += 3 + + if sub_category != "names": + continue + + ns_out_fn = op.join(output_dir, f"{file_base_name}_pred-task_neurosynth.csv") + gclda_out_fn = op.join(output_dir, f"{file_base_name}_pred-task_gclda.csv") + + if overwrite or not op.exists(gclda_out_fn): + gclda_img = _resample_for_reference(img, getattr(gclda_model, "mask", None)) + gclda_predictions_df, _ = gclda_decode_map(gclda_model, gclda_img) + gclda_predictions_df = gclda_predictions_df.sort_values(by="Weight", ascending=False).head(topk) + gclda_predictions_df = gclda_predictions_df.reset_index() + gclda_predictions_df.columns = ["pred", "weight"] + gclda_predictions_df.to_csv(gclda_out_fn, index=False) + gclda_outputs += 1 + + if overwrite or not op.exists(ns_out_fn): + ns_mask = None + if ns_masker is not None: + ns_mask = getattr(ns_masker, "mask_img", None) + if ns_mask is None: + ns_mask = getattr(ns_masker, "mask_img_", None) + ns_img = _resample_for_reference(img, ns_mask) + ns_predictions_df = ns_decoder.transform(ns_img) + feature_group = f"{source}-task_section-{section}_annot-tfidf__" + vocabulary_names = [feature.replace(feature_group, "") for feature in ns_predictions_df.index.values] + ns_predictions_df.index = vocabulary_names + ns_predictions_df = ns_predictions_df.sort_values(by="r", ascending=False).head(topk) + ns_predictions_df = ns_predictions_df.reset_index() + ns_predictions_df.columns = ["pred", "corr"] + ns_predictions_df.to_csv(ns_out_fn, index=False) + ns_outputs += 1 + + elapsed_seconds = time.perf_counter() - started_at + return { + **job, + "device": str(device), + "num_images": len(mapping_records), + "elapsed_seconds": elapsed_seconds, + "seconds_per_image": elapsed_seconds / max(len(mapping_records), 1), + "brainclip_outputs_written": brainclip_outputs, + "gclda_outputs_written": gclda_outputs, + "neurosynth_outputs_written": ns_outputs, + } + + +def _run_predictions( + mapping_df, + workdir, + sections, + model_ids, + sources, + sub_categories, + topk, + device, + devices, + num_workers, + overwrite, +): + jobs = list(_iter_prediction_jobs(sources, sections, model_ids, sub_categories)) + mapping_records = mapping_df.to_dict(orient="records") + resolved_devices = _resolve_devices(device=device, devices=devices) + num_workers = len(resolved_devices) if num_workers is None else num_workers + num_workers = max(1, min(num_workers, len(jobs))) + + summaries = [] + started_at = time.perf_counter() + + if num_workers == 1: + for job in tqdm(jobs, desc="configs", unit="config"): + summaries.append( + _run_prediction_job( + job, + mapping_records, + workdir, + topk=topk, + device=resolved_devices[0], + overwrite=overwrite, + show_progress=True, + ) + ) + else: + mp_context = mp.get_context("spawn") + with ProcessPoolExecutor(max_workers=num_workers, mp_context=mp_context) as executor: + future_to_job = {} + for job_idx, job in enumerate(jobs): + assigned_device = resolved_devices[job_idx % len(resolved_devices)] + future = executor.submit( + _run_prediction_job, + job, + mapping_records, + workdir, + topk, + assigned_device, + overwrite, + False, + ) + future_to_job[future] = job + + for future in tqdm(as_completed(future_to_job), total=len(future_to_job), desc="configs", unit="config"): + summaries.append(future.result()) + + elapsed_seconds = time.perf_counter() - started_at + summary_df = pd.DataFrame(summaries).sort_values( + by=["source", "section", "model_id", "sub_category"] + ) + summary_df["total_wall_seconds"] = elapsed_seconds + runtime_summary_fn = op.join(workdir, "results", "hcp_prediction_runtime.csv") + summary_df.to_csv(runtime_summary_fn, index=False) + + return op.join(workdir, "results", "predictions_hcp_nv"), runtime_summary_fn + + +def main( + workdir, + sections, + model_ids, + sources, + sub_categories, + topk=20, + device=None, + devices=None, + num_workers=None, + max_images=None, + overwrite=False, + skip_eval=False, + fetch_timeout=300, +): + workdir = op.abspath(workdir) + _prepare_minimal_assets( + workdir, + sections, + model_ids, + sources, + sub_categories, + fetch_timeout=fetch_timeout, + ) + mapping_df = _prepare_hcp_inputs(workdir, fetch_timeout=fetch_timeout) + _prepare_nilearn_assets(workdir) + if max_images is not None: + mapping_df = mapping_df.head(max_images).copy() + + ground_truth, documentation_df = _build_ground_truth(mapping_df, op.join(workdir, "data")) + ground_truth_path = op.join(workdir, "data", "hcp", "ground_truth.json") + os.makedirs(op.dirname(ground_truth_path), exist_ok=True) + with open(ground_truth_path, "w") as file_obj: + json.dump(ground_truth, file_obj, indent=2) + + documentation_df.to_csv(op.join(workdir, "data", "hcp", "benchmark_mapping.csv"), index=False) + + prediction_dir, runtime_summary_fn = _run_predictions( + mapping_df, + workdir, + sections=sections, + model_ids=model_ids, + sources=sources, + sub_categories=sub_categories, + topk=topk, + device=device, + devices=devices, + num_workers=num_workers, + overwrite=overwrite, + ) + + if skip_eval: + print(f"Skipped evaluation. Runtime summary written to {runtime_summary_fn}", flush=True) + return + + from jobs.decoding_eval import main as eval_main + + eval_main( + data_dir=op.join(workdir, "data"), + results_dir=op.join(workdir, "results"), + sections=sections, + model_ids=model_ids, + sources=sources, + categories=["task"], + sub_categories=sub_categories, + models=["neurosynth", "gclda", "brainclip"], + prediction_dir=prediction_dir, + image_dir=op.join(workdir, "data", "hcp", "neurovault"), + ground_truth_fn=ground_truth_path, + output_fn=op.join(workdir, "results", "eval-hcp-group_results.csv"), + ) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) + + +if __name__ == "__main__": + _main() diff --git a/jobs/roi_followup.py b/jobs/roi_followup.py new file mode 100644 index 0000000..15378de --- /dev/null +++ b/jobs/roi_followup.py @@ -0,0 +1,74 @@ +"""Aggregate ROI decoding outputs or document missing ROI assets.""" + +import argparse +import json +import os +import os.path as op +from glob import glob + +import pandas as pd + +from jobs.utils import resolve_project_paths + + +def _get_parser(): + parser = argparse.ArgumentParser(description="ROI follow-up summary for striatum and related seeds.") + parser.add_argument("--project_dir", dest="project_dir", default=None) + parser.add_argument("--data_dir", dest="data_dir", default=None) + parser.add_argument("--results_dir", dest="results_dir", default=None) + parser.add_argument("--prediction_dir", dest="prediction_dir", default=None) + parser.add_argument("--output_fn", dest="output_fn", required=True) + return parser + + +def _top_prediction(csv_path): + df = pd.read_csv(csv_path) + first = df.iloc[0] + value_col = "prob" if "prob" in df.columns else "weight" if "weight" in df.columns else "corr" + return {"pred": first["pred"], "score": float(first[value_col])} + + +def main(project_dir=None, data_dir=None, results_dir=None, prediction_dir=None, output_fn=None): + _, data_dir, results_dir = resolve_project_paths(project_dir, data_dir, results_dir) + prediction_dir = ( + op.join(results_dir, "predictions_rois") if prediction_dir is None else op.abspath(prediction_dir) + ) + seed_dir = op.join(data_dir, "seed-regions") + + if not op.isdir(seed_dir) and not op.isdir(prediction_dir): + output_df = pd.DataFrame( + [ + { + "status": "blocked", + "reason": "roi_seed_maps_missing", + "expected_seed_dir": seed_dir, + "expected_prediction_dir": prediction_dir, + } + ] + ) + else: + rows = [] + for task_csv in sorted(glob(op.join(prediction_dir, "*_pred-task_brainclip.csv"))): + stem = op.basename(task_csv).replace("_pred-task_brainclip.csv", "") + concept_csv = task_csv.replace("_pred-task_brainclip.csv", "_pred-concept_brainclip.csv") + process_csv = task_csv.replace("_pred-task_brainclip.csv", "_pred-process_brainclip.csv") + row = {"roi_label": stem, "status": "ready"} + row["top_task_json"] = json.dumps(_top_prediction(task_csv)) + if op.exists(concept_csv): + row["top_concept_json"] = json.dumps(_top_prediction(concept_csv)) + if op.exists(process_csv): + row["top_domain_json"] = json.dumps(_top_prediction(process_csv)) + rows.append(row) + output_df = pd.DataFrame(rows) + + os.makedirs(op.dirname(op.abspath(output_fn)), exist_ok=True) + output_df.to_csv(output_fn, index=False) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) + + +if __name__ == "__main__": + _main() diff --git a/jobs/snr_sweep.py b/jobs/snr_sweep.py new file mode 100644 index 0000000..1d8339a --- /dev/null +++ b/jobs/snr_sweep.py @@ -0,0 +1,69 @@ +"""Run or document the HCP subject-level group-size sensitivity analysis.""" + +import argparse +import os +import os.path as op + +import pandas as pd + +from jobs.utils import resolve_project_paths + + +def _get_parser(): + parser = argparse.ArgumentParser( + description="Group-size sensitivity analysis for subject-level HCP maps." + ) + parser.add_argument("--project_dir", dest="project_dir", default=None) + parser.add_argument("--data_dir", dest="data_dir", default=None) + parser.add_argument("--results_dir", dest="results_dir", default=None) + parser.add_argument( + "--subject_dir", + dest="subject_dir", + default=None, + help="Directory containing subject-level HCP maps.", + ) + parser.add_argument( + "--output_fn", + dest="output_fn", + required=True, + help="CSV report path.", + ) + return parser + + +def main(project_dir=None, data_dir=None, results_dir=None, subject_dir=None, output_fn=None): + _, data_dir, _ = resolve_project_paths(project_dir, data_dir, results_dir) + subject_dir = op.join(data_dir, "hcp_subject") if subject_dir is None else op.abspath(subject_dir) + + if not op.isdir(subject_dir): + report_df = pd.DataFrame( + [ + { + "status": "blocked", + "reason": "subject_level_hcp_maps_missing", + "expected_path": subject_dir, + } + ] + ) + else: + report_df = pd.DataFrame( + [ + { + "status": "ready", + "reason": "subject_level_hcp_maps_present", + "expected_path": subject_dir, + } + ] + ) + + os.makedirs(op.dirname(op.abspath(output_fn)), exist_ok=True) + report_df.to_csv(output_fn, index=False) + + +def _main(argv=None): + options = _get_parser().parse_args(argv) + main(**vars(options)) + + +if __name__ == "__main__": + _main() diff --git a/jobs/utils.py b/jobs/utils.py index fe77f48..741a6c0 100644 --- a/jobs/utils.py +++ b/jobs/utils.py @@ -1,4 +1,41 @@ +"""Shared helpers for analysis and decoding jobs.""" + +import argparse +import ast +import json +import os.path as op + import numpy as np +import pandas as pd + +DEFAULT_MODEL_IDS = [ + "BrainGPT/BrainGPT-7B-v0.2", + "mistralai/Mistral-7B-v0.1", + "BrainGPT/BrainGPT-7B-v0.1", + "meta-llama/Llama-2-7b-chat-hf", +] +DEFAULT_SECTIONS = ["abstract", "body"] + + +def _read_prior(prior_fn): + if op.exists(prior_fn): + return np.load(prior_fn) + + csv_prior_fn = f"{op.splitext(prior_fn)[0]}.csv" + if not op.exists(csv_prior_fn): + raise FileNotFoundError(f"Could not find prior file {prior_fn} or CSV fallback {csv_prior_fn}.") + + prior_df = pd.read_csv(csv_prior_fn) + if "prior" in prior_df.columns: + prior_values = prior_df["prior"].to_numpy() + else: + numeric_columns = prior_df.select_dtypes(include="number").columns + if len(numeric_columns) == 0: + raise ValueError(f"CSV prior fallback {csv_prior_fn} does not contain a numeric prior column.") + prior_values = prior_df[numeric_columns[-1]].to_numpy() + + np.save(prior_fn, prior_values) + return prior_values def _read_vocabulary(vocabulary_fn, vocabulary_emb_fn, vocabulary_prior_fn=None): @@ -6,6 +43,221 @@ def _read_vocabulary(vocabulary_fn, vocabulary_emb_fn, vocabulary_prior_fn=None) vocabulary = [line.strip() for line in f] if vocabulary_prior_fn is not None: - return vocabulary, np.load(vocabulary_emb_fn), np.load(vocabulary_prior_fn) - else: - return vocabulary, np.load(vocabulary_emb_fn) + return vocabulary, np.load(vocabulary_emb_fn), _read_prior(vocabulary_prior_fn) + + return vocabulary, np.load(vocabulary_emb_fn) + + +def str_to_bool(value): + """Parse common string forms of booleans for argparse.""" + if isinstance(value, bool): + return value + + normalized = value.strip().lower() + if normalized in {"1", "true", "t", "yes", "y"}: + return True + if normalized in {"0", "false", "f", "no", "n"}: + return False + + raise argparse.ArgumentTypeError(f"Cannot interpret boolean value from {value!r}.") + + +def get_default_project_dir(): + """Return the repository root based on the current jobs directory.""" + return op.abspath(op.join(op.dirname(__file__), "..")) + + +def resolve_project_paths(project_dir=None, data_dir=None, results_dir=None): + """Resolve project-relative data and results directories.""" + project_dir = get_default_project_dir() if project_dir is None else op.abspath(project_dir) + data_dir = op.join(project_dir, "data") if data_dir is None else op.abspath(data_dir) + results_dir = op.join(project_dir, "results") if results_dir is None else op.abspath(results_dir) + return project_dir, data_dir, results_dir + + +def get_source(reduced): + """Return the ontology source label for a reduced/full vocabulary.""" + return "cogatlasred" if reduced else "cogatlas" + + +def get_model_name(model_id): + """Extract the short model name used in filenames.""" + return model_id.split("/")[-1] + + +def strip_nii_suffix(path_or_name): + """Return a filename stem while preserving inner dots in `.nii.gz` names.""" + filename = op.basename(path_or_name) + if filename.endswith(".nii.gz"): + return filename[:-7] + return op.splitext(filename)[0] + + +def infer_prediction_label(path_or_name, delimiter="_", token_index=None): + """Infer the prediction label used in output filenames from an image filename.""" + stem = strip_nii_suffix(path_or_name) + if token_index is None: + return stem + + parts = stem.split(delimiter) + if token_index >= len(parts) or token_index < -len(parts): + raise ValueError( + f"Cannot extract token index {token_index} from {path_or_name!r} using delimiter {delimiter!r}." + ) + return parts[token_index] + + +def parse_name_list(value): + """Parse a CSV/JSON-ish field into a normalized list of label strings.""" + if value is None: + return [] + + if isinstance(value, float) and np.isnan(value): + return [] + + if isinstance(value, np.ndarray): + value = value.tolist() + + if isinstance(value, (list, tuple)): + return [str(item).strip() for item in value if str(item).strip()] + + if isinstance(value, str): + normalized = value.strip() + if not normalized: + return [] + + if normalized[0] in {"[", "("}: + try: + parsed = ast.literal_eval(normalized) + except (SyntaxError, ValueError): + parsed = None + if isinstance(parsed, (list, tuple)): + return [str(item).strip() for item in parsed if str(item).strip()] + + for delimiter in (";", "|", ","): + if delimiter in normalized: + return [item.strip() for item in normalized.split(delimiter) if item.strip()] + + return [normalized] + + return [str(value).strip()] + + +def add_common_job_args(parser): + """Add the shared CLI surface used by analysis jobs.""" + parser.add_argument( + "--project_dir", + dest="project_dir", + default=get_default_project_dir(), + help="Path to the repository root. Defaults to the parent of the jobs directory.", + ) + parser.add_argument( + "--data_dir", + dest="data_dir", + default=None, + help="Optional explicit path to the data directory. Defaults to /data.", + ) + parser.add_argument( + "--results_dir", + dest="results_dir", + default=None, + help="Optional explicit path to the results directory. Defaults to /results.", + ) + parser.add_argument( + "--sections", + dest="sections", + nargs="+", + default=list(DEFAULT_SECTIONS), + help="Text sections to evaluate. Defaults to abstract and body.", + ) + parser.add_argument( + "--model_ids", + dest="model_ids", + nargs="+", + default=list(DEFAULT_MODEL_IDS), + help="One or more embedding model identifiers to evaluate.", + ) + parser.add_argument( + "--device", + dest="device", + default=None, + help="Device to use for computation. Defaults to the braindec device helper.", + ) + parser.add_argument( + "--reduced", + dest="reduced", + type=str_to_bool, + default=True, + help="Whether to use the reduced Cognitive Atlas vocabulary. Defaults to true.", + ) + parser.add_argument( + "--topk", + dest="topk", + type=int, + default=20, + help="Number of predictions to keep per image.", + ) + parser.add_argument( + "--standardize", + dest="standardize", + type=str_to_bool, + default=False, + help="Whether to standardize images before embedding.", + ) + parser.add_argument( + "--logit_scale", + dest="logit_scale", + type=float, + default=20.0, + help="Override CLIP logit scale used during decoding.", + ) + return parser + + +def build_cognitiveatlas(data_dir, reduced, concept_to_process_fn=None): + """Construct a CognitiveAtlas object from local snapshots.""" + from braindec.cogatlas import CognitiveAtlas + + reduced_tasks_fn = op.join(data_dir, "cognitive_atlas", "reduced_tasks.csv") + reduced_tasks_df = pd.read_csv(reduced_tasks_fn) if reduced else None + + concept_to_process = None + if concept_to_process_fn is not None and op.exists(concept_to_process_fn): + with open(concept_to_process_fn, "r") as file: + concept_to_process = json.load(file) + + return CognitiveAtlas( + data_dir=data_dir, + task_snapshot=op.join(data_dir, "cognitive_atlas", "task_snapshot-02-19-25.json"), + concept_snapshot=op.join(data_dir, "cognitive_atlas", "concept_extended_snapshot-02-19-25.json"), + concept_to_process=concept_to_process, + reduced_tasks=reduced_tasks_df, + ) + + +def load_decoding_resources(results_dir, voc_dir, source, category, sub_category, model_id, section): + """Load model and vocabulary artifacts for a decoding run.""" + model_name = get_model_name(model_id) + model_path = op.join( + results_dir, + "pubmed", + f"model-clip_section-{section}_embedding-{model_name}_best.pth", + ) + vocabulary_label = f"vocabulary-{source}_{category}-{sub_category}_embedding-{model_name}" + vocabulary_fn = op.join(voc_dir, f"vocabulary-{source}_{category}.txt") + vocabulary_emb_fn = op.join(voc_dir, f"{vocabulary_label}.npy") + vocabulary_prior_fn = op.join(voc_dir, f"{vocabulary_label}_section-{section}_prior.npy") + vocabulary, vocabulary_emb, vocabulary_prior = _read_vocabulary( + vocabulary_fn, + vocabulary_emb_fn, + vocabulary_prior_fn, + ) + + return { + "model_name": model_name, + "model_path": model_path, + "vocabulary_label": vocabulary_label, + "vocabulary": vocabulary, + "vocabulary_emb": vocabulary_emb, + "vocabulary_prior": vocabulary_prior, + } diff --git a/pyproject.toml b/pyproject.toml index c81c420..584aadf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,16 +34,16 @@ dependencies = [ dynamic = ["version"] [project.urls] -Homepage = "https://github.com/JulioAPeraza/braindec" +Homepage = "https://github.com/jdkent/brain-decoder" Documentation = "https://braindec.readthedocs.io/en/latest/" [project.optional-dependencies] doc = [ "sphinx>=2.0", "sphinx_rtd_theme", - "sphinx-argparse", "sphinx-copybutton", "sphinx-gallery", + "jupytext", ] plotting = ["neuromaps", "surfplot"] diff --git a/review.md b/review.md new file mode 100644 index 0000000..72e4b3f --- /dev/null +++ b/review.md @@ -0,0 +1,438 @@ +# Reviewer Analysis Plan + +This document translates the commitments in `review_response.md` into a concrete analysis plan for `brain-decoder`. + +## Overview + +The reviewer response implies three kinds of work: + +1. Reproduce and harden the existing decoding pipeline so the current HCP, IBC, and ROI analyses can be rerun reliably. +2. Add missing analyses that were explicitly promised in the response letter. +3. Package the outputs into figures, tables, and supplementary artifacts that map cleanly onto manuscript revisions. + +## Reflection Rule + +After each task in this plan is completed, write a short reflection before moving on to the next task. Each reflection should include: + +1. What worked. +2. What did not work or remained ambiguous. +3. The next checks or follow-up analyses still needed to fully answer the reviewer's question tied to that task. + +Treat these reflections as part of the deliverable for each task, not as optional notes. + +## Current Repo Status + +The repository already contains partial scaffolding for several decoding analyses: + +- `jobs/decoding_hcp_nv.py` for HCP group-map decoding. +- `jobs/decoding_eval.py` for aggregate HCP evaluation. +- `jobs/decoding_ibc.py` for IBC decoding. +- `jobs/decoding_seeds.py` for ROI decoding. +- `braindec/predict.py` for task/concept/domain decoding. +- `braindec/cogatlas.py` for ontology construction and task/concept/domain mappings. + +There are also important gaps that should be treated as prerequisites: + +- `jobs/decoding_cnp.py` is referenced in `review_response.md` but does not exist in the repo. +- `jobs/decoding_ibc.py` and `jobs/decoding_seeds.py` appear stale relative to the current `image_to_labels_hierarchical` API. +- Multiple job scripts hard-code `project_dir` and `device`, which will block reproducible reruns. +- There is no script yet for per-term analysis, null/permutation baselines, embedding geometry, or SNR/sample-size sweeps. + +## Recommended Execution Order + +1. Stabilize the decoding jobs and evaluation interfaces. +2. Rerun and document the current HCP benchmark. +3. Add cross-dataset decoding for IBC and CNP. +4. Add term-level, ontology, and chance-normalized analyses. +5. Add embedding-geometry and SNR analyses. +6. Add the targeted emotion and striatum follow-up analyses. +7. Optionally run an NSD pilot if data are available and time permits. + +## Task 0: Pipeline Hardening + +### Goal + +Make the current decoding and evaluation scripts runnable, configurable, and consistent with the current `braindec` APIs. + +### Non-goals + +- Changing model architecture. +- Re-training the CLIP model. +- Refactoring the whole repository. + +### Concrete Implementation Plan + +1. Update `jobs/decoding_hcp_nv.py`, `jobs/decoding_ibc.py`, `jobs/decoding_seeds.py`, and `jobs/decoding_eval.py` to accept CLI arguments for `project_dir`, `data_dir`, `results_dir`, `device`, `section`, `model_id`, and `reduced`. +2. Fix stale calls to `image_to_labels_hierarchical` so all jobs pass a `CognitiveAtlas` object rather than raw arrays. +3. Remove hard-coded `mps` device usage and default to a CLI/device helper. +4. Standardize output directory layout across HCP, IBC, CNP, ROI, and future analyses. +5. Verify that ground-truth files are loaded from the correct dataset-specific locations. +6. Add a small README section or module docstring describing expected inputs and outputs for each analysis job. + +### Deliverables + +- Runnable decoding jobs. +- Consistent prediction CSV outputs. +- A stable base for all downstream analyses. + +## Task 1: HCP Benchmark Reproduction And Documentation + +### Goal + +Reproduce the core HCP group-map decoding results and explicitly document the task-level mapping used in the manuscript. + +### Non-goals + +- Expanding the benchmark beyond the currently selected representative task-domain contrasts. +- Fairness comparison with external IBMA models on HCP. + +### Concrete Implementation Plan + +1. Confirm the seven representative HCP contrasts used for task-level evaluation: + - Emotion: Faces vs Shapes + - Gambling: Reward vs Baseline + - Language: Story vs Math + - Motor: Average + - Relational: Relational vs Match + - Social: TOM vs Random + - Working Memory: 2-Back vs 0-Back +2. Encode this mapping in a dedicated metadata file rather than leaving it implicit in filenames. +3. Rerun `jobs/decoding_hcp_nv.py` after pipeline hardening. +4. Update `jobs/decoding_eval.py` so it reads the explicit mapping file and computes task, concept, and domain metrics deterministically. +5. Export a supplementary table containing: + - HCP image id + - selected contrast + - mapped task + - mapped concepts + - mapped domain + - top-k decoded predictions + +### Deliverables + +- Reproducible HCP benchmark outputs. +- Supplementary mapping table for task-to-concept/domain interpretation. + +## Task 2: Cross-Dataset Generalization On IBC + +### Goal + +Show that NiCLIP generalizes beyond HCP by decoding IBC statistical maps. + +### Non-goals + +- Full IBC methodological harmonization. +- Training on IBC. + +### Concrete Implementation Plan + +1. Repair and parameterize `jobs/decoding_ibc.py`. +2. Create an explicit IBC ground-truth mapping file from image names to task, concept, and domain labels. +3. Reuse the same evaluation interface as HCP so metrics are comparable. +4. Produce aggregate metrics: + - task Recall@4 + - concept Recall@4 + - domain Recall@2 +5. Break out results by task family where possible, especially emotion-related tasks if present. +6. Export per-image prediction tables for supplement use. + +### Deliverables + +- IBC prediction outputs. +- IBC summary metrics. +- Cross-dataset comparison against HCP. + +## Task 3: Cross-Dataset Generalization On CNP + +### Goal + +Add the missing CNP benchmark promised in the response letter. + +### Non-goals + +- Broad psychiatric interpretation beyond task decoding. +- Subject-level clinical modeling. + +### Concrete Implementation Plan + +1. Add a new `jobs/decoding_cnp.py` modeled after the HCP and IBC jobs. +2. Define a CNP ground-truth mapping file for tasks such as BART, PAM encoding/retrieval, SCAP, Stop Signal, and Task Switching. +3. Ensure the CNP job writes outputs into a dataset-specific predictions directory with the same CSV schema used elsewhere. +4. Add or extend an evaluation script to score CNP task/concept/domain Recall@K. +5. Summarize which CNP tasks transfer well and which do not. + +### Deliverables + +- New CNP decoding job. +- CNP evaluation outputs. +- Cross-dataset generalization section inputs. + +## Task 4: Fine-Grained Emotion Analysis + +### Goal + +Test whether NiCLIP can decode more specific emotional states than the coarse HCP emotion contrast. + +### Non-goals + +- Open-ended text generation of emotions. +- Claiming fine-grained emotion decoding if ontology coverage is weak. + +### Concrete Implementation Plan + +1. Identify emotion-specific contrasts in IBC and/or CNP that are more granular than `Faces vs Shapes`. +2. Audit the current Cognitive Atlas task and concept vocabulary for emotion-specific labels such as fear, anger, disgust, and happiness. +3. If needed, extend the vocabulary-generation path so emotion concepts and task definitions are included in the decoding vocabulary. +4. Generate vocabulary embeddings and priors for the expanded emotion-related vocabulary. +5. Decode the selected emotion maps and compare: + - coarse emotion-domain predictions + - emotion-concept predictions + - fine-grained task predictions +6. Report whether failures are due to ontology coverage, training-data sparsity, or embedding mismatch. + +### Deliverables + +- Emotion-specific decoding results. +- A clear statement of current granularity limits. + +## Task 5: Per-Term Decodability Analysis + +### Goal + +Quantify which individual tasks are actually decodable, rather than reporting only averaged benchmark scores. + +### Non-goals + +- Estimating calibrated probabilities. +- Claiming full coverage of the entire ontology if only a subset is evaluable. + +### Concrete Implementation Plan + +1. Add a new analysis script, for example `jobs/per_term_eval.py`. +2. Aggregate prediction outputs from HCP, IBC, and CNP into a single evaluation table. +3. For each task in the reduced ontology, compute: + - number of evaluation maps tied to that task + - Recall@K + - rank statistics for the ground-truth term +4. Produce: + - a per-task bar chart or heatmap + - a table of best-decoded and worst-decoded tasks +5. Explicitly mark tasks with insufficient evaluation examples so they are not over-interpreted. + +### Deliverables + +- Per-task decodability table. +- Figure summarizing term-level decoding performance. + +## Task 6: Null Baseline And Normalized Accuracy + +### Goal + +Contextualize raw task/concept/domain scores with chance-level and permutation-based baselines. + +### Non-goals + +- Formal probability calibration. +- Bayesian model comparison. + +### Concrete Implementation Plan + +1. Add a permutation procedure that shuffles ground-truth label assignments within each dataset. +2. Recompute Recall@K under the null distribution for tasks, concepts, and domains. +3. Estimate: + - empirical chance baseline + - p-value or percentile above null + - normalized accuracy = observed / chance +4. Report how many terms are above chance in the per-term analysis. +5. Add a simple table comparing observed versus chance performance across label levels. + +### Deliverables + +- Permutation-based baseline outputs. +- Chance-normalized metrics. +- Count of above-chance decodable terms. + +## Task 7: Factors Associated With Per-Term Performance + +### Goal + +Explain why some terms decode better than others. + +### Non-goals + +- Causal claims about ontology quality or article frequency. +- Full statistical modeling of all confounds. + +### Concrete Implementation Plan + +1. Build a term-level analysis table with one row per task. +2. Add candidate explanatory variables: + - number of associated training articles + - definition length + - embedding norm or other simple text-feature proxies + - map specificity measures derived from evaluation maps +3. Compute correlations and simple regression analyses between these variables and per-term Recall@K. +4. Report whether article frequency, definition quality/length, or map specificity appears most associated with performance. + +### Deliverables + +- Term-level feature table. +- Correlation/regression summary figure or table. + +## Task 8: Reduced Versus Full Cognitive Atlas Comparison + +### Goal + +Substantiate the claim that the reduced/curated ontology improves decoding. + +### Non-goals + +- Rebuilding the ontology from scratch. +- Exhaustive ontology curation. + +### Concrete Implementation Plan + +1. Export summary statistics for the full and reduced ontology: + - number of tasks + - number of concepts + - task-concept edges + - concept-domain edges +2. Rerun a matched decoding/evaluation comparison under both ontologies. +3. Compare performance for each model/configuration where the ontology changes are the only intended difference. +4. Produce a supplementary table documenting what was retained, removed, or enriched in the reduced ontology. + +### Deliverables + +- Ontology comparison table. +- Performance comparison under full versus reduced CogAt. + +## Task 9: Embedding Geometry Analysis + +### Goal + +Inspect the shared latent space to determine whether aligned task and image embeddings form meaningful structure. + +### Non-goals + +- Treating 2D projections as definitive evidence. +- Over-interpreting global geometry from UMAP/t-SNE alone. + +### Concrete Implementation Plan + +1. Add a new script, for example `jobs/embedding_geometry.py`. +2. Extract: + - task text embeddings from the decoding vocabulary + - HCP image embeddings from the trained CLIP image encoder +3. Produce 2D projections with UMAP and/or t-SNE. +4. Color points by: + - cognitive domain for task embeddings + - task family for image embeddings +5. Quantify auxiliary structure: + - within-domain versus between-domain distances + - distance from each image embedding to its matched task embedding + - comparison of geometry across BrainGPT, Mistral, and Llama variants + +### Deliverables + +- Supplementary latent-space figure. +- Distance-based summary table. + +## Task 10: SNR / Group-Size Sensitivity Analysis + +### Goal + +Test whether subject-level underperformance is mainly due to lower SNR by varying HCP group size. + +### Non-goals + +- Solving subject-level decoding in this revision. +- Training with synthetic noise augmentation unless needed as a later follow-up. + +### Concrete Implementation Plan + +1. Identify or generate HCP subject-level maps for the benchmark tasks. +2. Build group-average maps for subset sizes `N = {5, 10, 20, 50, 100, 200, 787}`. +3. For each group size: + - sample multiple subsets if possible + - decode each averaged map + - compute mean and variance of Recall@K +4. Plot decoding performance versus group size. +5. Compare the trend against the qualitative covariate-shift explanation to determine whether noise alone explains the gap. + +### Deliverables + +- Group-size sensitivity figure. +- Quantitative statement about SNR versus text/image mismatch. + +## Task 11: ROI Follow-Up And Striatum Hypothesis Example + +### Goal + +Use the existing ROI decoding setup to produce a concrete hypothesis-generation example centered on striatum-language associations. + +### Non-goals + +- Claiming novel neurobiological discovery solely from the decoder. +- Replacing literature validation with decoding outputs. + +### Concrete Implementation Plan + +1. Repair and rerun `jobs/decoding_seeds.py`. +2. Export top task, concept, and domain predictions for each ROI. +3. For striatum specifically, record: + - top semantic/language tasks + - top concepts + - top domain probabilities +4. Package the striatum results into a concise table suitable for the manuscript and a separate table for supplement use. +5. Hand off the exact predicted labels and scores for downstream literature validation and meta-analytic discussion. + +### Deliverables + +- Updated ROI decoding outputs. +- Striatum-specific results table for the hypothesis-generation example. + +## Task 12: Optional NSD Pilot + +### Goal + +Probe the boundary of model generalization on a natural-scene perception dataset. + +### Non-goals + +- Treating NSD as a core benchmark. +- Claiming scene-level semantic decoding if the ontology is mismatched. + +### Concrete Implementation Plan + +1. Only proceed if NSD maps are already available in a usable format. +2. Define a small pilot set of maps expected to load strongly on visual perception. +3. Run decoding with the existing perception-related ontology terms. +4. Report the pilot qualitatively and quantitatively as a boundary case, not a headline result. + +### Deliverables + +- Optional appendix-level NSD pilot result. + +## Suggested New Files + +- `jobs/decoding_cnp.py` +- `jobs/per_term_eval.py` +- `jobs/null_baseline.py` +- `jobs/embedding_geometry.py` +- `jobs/snr_sweep.py` +- `data/.../ground_truth_*.json` or `.csv` files for HCP, IBC, and CNP mappings +- `results/.../tables/` and `results/.../figures/` subdirectories for manuscript-ready outputs + +## Minimal Milestone Cut + +If time is constrained, the minimum credible set for the revision is: + +1. Task 0: Pipeline hardening. +2. Task 1: HCP benchmark reproduction. +3. Task 2: IBC cross-dataset analysis. +4. Task 3: CNP cross-dataset analysis. +5. Task 5: Per-term decodability. +6. Task 6: Null/chance baseline analysis. +7. Task 8: Reduced versus full ontology comparison. +8. Task 10: SNR/group-size analysis. + +The embedding geometry, striatum deep dive, emotion expansion, and NSD pilot are valuable but can be treated as secondary if deadlines are tight. diff --git a/review_response.md b/review_response.md new file mode 100644 index 0000000..346b934 --- /dev/null +++ b/review_response.md @@ -0,0 +1,707 @@ +# **Response to Reviewers: Manuscript NCOMMS-25-63915-T** + +We thank the Reviewers for their helpful and constructive comments regarding our manuscript. + +The Reviewers brought up several valuable critiques of the manuscript, which we have addressed in the revised manuscript. Below, these concerns are addressed on a point-by-point basis. Revisions to the manuscript are shown in red font. + +# **Reviewer 1 Comments** + +Summary: +This paper introduces NiCLIP, a neuroimaging contrastive language-image pretraining model, aiming to predict cognitive tasks, concepts, and domains from brain activation patterns. The authors trained NiCLIP on over 23,000 neuroscientific articles, leveraging large language models (LLMs) and deep contrastive learning to establish text-to-brain associations. Key findings indicate that fine-tuned LLMs and the use of full-text articles with a curated cognitive ontology optimize NiCLIP's predictive accuracy. Evaluations, particularly with group-level activation maps from the Human Connectome Project, show NiCLIP's capability to accurately predict cognitive tasks across various domains (e.g., emotion, language, motor) and characterize the functional roles of specific brain regions. However, the model exhibits limitations with noisy subject-level activation maps. While presented as a significant advancement for hypothesis generation in neuroimaging, my assessment is that despite the extensive literature data and thorough experimental validation demonstrating the model's effectiveness in reverse-decoding brain states, the article's innovation is insufficient. The comparison with baseline methods appears to be not entirely fair, and the discussion of future prospects lacks both depth and convincing evidence. + +We appreciate Reviewer \#1's comments and their recognition that NiCLIP demonstrates "extensive literature data and thorough experimental validation" for decoding brain maps. We address below the concerns regarding innovation, experimental comparisons, and application prospects. + +## **Reviewer 1 Comment 1** + +1\. Lack of Innovation +The training pipeline for this paper is nearly identical to that of NeuroConText, with only minor differences in the inference stage. While NeuroConText indexes corresponding brain activity from text, this paper indexes text descriptions from a knowledge graph based on brain activation. The authors highlight the "lack of validation for reverse inference (brain imaging to text) in functional decoding" as a primary shortcoming of previous work. However, once a brain-text contrastive learning model is trained, bidirectional decoding (both brain-to-text and text-to-brain) becomes feasible. Therefore, this cannot be considered a distinctive contribution of the current work. + +### **Response** + +We thank the reviewer for this observation. We acknowledge that both NiCLIP and NeuroConText share a common foundation: using a CLIP-based contrastive learning framework to align brain activation maps with text in a shared latent space. We stated in the Introduction and Methods section about such similarities: + +“The CLIP model architecture (Fig. 4A) adheres to the identical settings employed in the NeuroConText framework (Meudec et al., 2024).” + +While the reviewer is correct that a trained CLIP model in principle enables bidirectional retrieval, we emphasize that \*\*the contribution of NiCLIP is not simply "reversing the arrow" of inference\*\*. Rather, the novelty lies in the combined integration of three specific advances that together constitute a qualitatively different system for functional decoding: + +1\. Ontology-driven Bayesian decoding framework. NeuroConText operates purely data-driven: it retrieves the nearest text embedding in latent space. In contrast, NiCLIP introduces a structured Bayesian decoding pipeline (Section 5.3) where the CLIP similarity is used as the likelihood P(A|T) within Bayes' theorem, combined with literature-derived priors P(T), to compute posterior probabilities P(T|A) for tasks. This posterior is then propagated through a cognitive ontology using the noisy-OR model to predict concepts P(C|A) and domains P(D|A). This hierarchical, ontology-grounded inference is not a feature of NeuroConText and represents a fundamentally different approach to interpreting brain activation patterns. + +2\. Domain-specific LLMs (BrainGPT). NiCLIP is the first to systematically evaluate and demonstrate the benefit of neuroscience-specific fine-tuned LLMs (BrainGPT-7B-v0.1 and v0.2) versus their base models (Llama-2 and Mistral) for text-to-brain association. Our results show that BrainGPT-7B-v0.2 provides superior text-to-brain associations compared to general-purpose LLMs (Table 1). + +3\. Integration of the Cognitive Atlas ontology for structured predictions. Unlike NeuroConText, which outputs free-text retrievals (i.e., without guardrails), NiCLIP maps predictions to a curated cognitive vocabulary with task-to-concept and concept-to-domain mappings, enabling structured, interpretable outputs at multiple granularity levels. We demonstrated that the choice of ontology significantly impacts decoding accuracy (Table 2). + +We will revise the introduction (Lines 33–34 of the current manuscript) to more clearly delineate the shared training pipeline and articulate these distinct contributions: + +"While NiCLIP shares the CLIP-based contrastive training framework with NeuroConText (Meudec et al., 2024), it advances beyond text-to-brain retrieval by introducing: (1) a Bayesian decoding framework that integrates CLIP-derived likelihoods with literature priors and structured ontologies for hierarchical reverse inference, (2) the systematic evaluation of neuroscience-specific LLMs for improved text-brain alignment, and (3) the first formal validation of ontology-driven functional decoding in a contrastive learning framework." + +## **Reviewer 1 Comment 2** + +Experimental Issues + +## **Reviewer 1 Comment 2.1** + +Why was NeuroConText not included as a baseline for comparison in this study?. + +### **Response** + +NeuroConText and NiCLIP address fundamentally different directions of the brain-text mapping problem. NeuroConText is designed to predict brain activation maps from text (text → brain), whereas NiCLIP is designed to predict text from brain activation maps (brain → text). This asymmetry means that a direct decoding comparison is not feasible, as the two models solve different tasks. + +However, a meaningful comparison is possible at the level of text-to-brain association (i.e., the quality of the learned brain-text embedding space). In Table 1, we present results comparing our best-performing setting using BrainGPT embeddings against the best-performing base model setting from NeuroConText. This comparison demonstrates that NiCLIP achieves superior text-to-brain association performance. + +We will expand this discussion in the revised manuscript to clarify the complementary nature of the two approaches: NeuroConText is capable of generating brain maps from arbitrary text descriptions, while NiCLIP provides structured Bayesian decoding from brain maps to cognitive tasks, concepts, and domains grounded in a cognitive ontology. + +**Section:** + **Page** + +On average,. + + + +## **Reviewer 1 Comment 2.2** + +Figures 2 and 3 only present titles for the descriptions of each brain region. Given that the model utilized full-text information during training, why are more detailed descriptions not provided?. + +### **Response** + +NiCLIP's decoding framework deliberately predicts from a structured vocabulary of Cognitive Atlas tasks rather than generating open-ended text. This was a purposeful design choice to set guardrails on the model's predictions and improve interpretability. While open-text prediction is an interesting direction, unconstrained text generation from brain maps is prone to hallucination, making the resulting predictions difficult to interpret and validate. By connecting predictions to a curated cognitive ontology, each decoded label carries a well-defined meaning with associated concept mappings and domain classifications. + +The predictions displayed in Figures 2 and 3 are Cognitive Atlas task names because the decoding stage computes posterior probabilities for each task in the vocabulary, where the task embedding is the weighted combination of its name and definition (Eq. in Section 5.3, λ \= 0.5). While training uses full-text articles, the decoding projects task names and definitions into the shared latent space to compute similarities. + +To address the reviewer's concern, in the revision we will add a supplementary table listing all predicted tasks with their full Cognitive Atlas definitions for each decoded map, allowing readers to appreciate the richer semantic content driving the predictions. + + **Section: Results** + **Page 26** + +Together. + +## **Reviewer 1 Comment 2.3** + +The paper states that "the most commonly used 'association-based decoders' are not based on formal models, cannot identify underlying structures related to specific cognitive processes, and lack sensitivity to unseen brain patterns." However, I found no evidence in this manuscript to support the claim that NiCLIP can interpret unseen brain patterns. + +### **Response** + +We acknowledge that this claim requires clearer evidence. The statement about "unseen brain patterns" refers to NiCLIP's capacity to generalize beyond the specific brain maps present in its training data, which are modeled activation maps from PubMed coordinates. Specifically: + +1\. Generalization to unseen map types. The HCP group-level statistical maps used for evaluation (Section 2.3, Figures 2 and 3\) are whole-brain t-statistic maps that are structurally different from the MKDA-modeled coordinate maps used for training. The fact that NiCLIP correctly predicts tasks from these unseen map types (e.g., Motor task predicted at 48.3%, Relational processing at 54.5%) demonstrates generalization. + +2\. ROI-based decoding. The six ROI maps (Figure 3\) represent a qualitatively different input type, binary/sparse masks rather than continuous statistical maps, yet NiCLIP produces highly selective predictions (e.g., rTPJ → social cognition at 98.5%). + +3\. Cross-domain generalization. NiCLIP can predict tasks it was never explicitly trained on, since the vocabulary embeddings are computed from Cognitive Atlas definitions, not from training data labels. Any task with a name and definition in the ontology can be included in the prediction vocabulary. + +In the revision, we will: +\- Add a dedicated subsection validating generalization by evaluating NiCLIP on additional datasets beyond HCP (see response to point 2(5) below). +\- Clarify the wording to: \*"NiCLIP's self-supervised CLIP architecture learns a shared latent space from text-image pairs, enabling it to generalize to brain maps and cognitive tasks not present in the training data, including different map types (statistical maps vs. modeled coordinates) and novel vocabulary terms."\* + +**Section: Introduction** + **Page 2** + +Recent . + +## **Reviewer 1 Comment 2.4** + +In Figure 2, the model is shown to identify the current task (e.g., emotion) based on brain activation states. I am curious whether the model can further identify specific emotions, such as amusement or disgust. + +### **Response** + +This is an excellent question. The current Cognitive Atlas ontology includes some specific emotion tasks (e.g., "emotion regulation task," "emotional face recognition task") and concepts (e.g., "fear," "anger," "disgust," "happiness"), but the granularity of emotion-specific predictions depends on the vocabulary and training data coverage. + +For the revision, we will: + +1\. Test NiCLIP on emotion-specific contrasts. The HCP emotion processing task includes a "Faces vs. Shapes" contrast that is relatively coarse. We will additionally evaluate NiCLIP on more fine-grained emotion contrasts from other datasets (e.g., the CNP dataset, which we have already performed decoding, or the IBC dataset, which includes multiple emotion-related contrasts). + +2\. Expand the Cognitive Atlas vocabulary to include specific emotion concepts and evaluate whether NiCLIP can discriminate between specific emotional states when presented with activation maps from paradigms that target individual emotions. + +3\. Discuss the granularity limitation in the revised manuscript, noting that NiCLIP's ability to distinguish fine-grained cognitive states depends on (a) the specificity of the ontology vocabulary and (b) the diversity of training data covering such distinctions. + +**Section: Materials and Methods** + **Page 8** + +Following seg. + +## **Reviewer 1 Comment 2.5** + +Can this model be generalized across datasets during inference? For example, would it operate effectively on the NSD dataset?. + +### **Response** + +We agree that demonstrating cross-dataset generalization is essential, and we have now completed this analysis on two additional NeuroVault datasets using the same reduced Cognitive Atlas evaluation interface as for HCP. + +1\. Individual Brain Charting (IBC) dataset: We decoded 1,608 IBC maps spanning 11 task families. In the reduced ontology benchmark, NiCLIP achieved task Recall@4 \= 15.67%, concept Recall@4 \= 10.38%, and domain Recall@2 \= 33.74% for the combined task-name+definition vocabulary. Using task names alone, NiCLIP achieved task Recall@4 \= 8.46%, concept Recall@4 \= 6.30%, and domain Recall@2 \= 40.55%. For comparison, on the task-only benchmark the association-based baselines scored 16.36% (GCLDA) and 7.34% (Neurosynth) task Recall@4. + +2\. Consortium for Neuropsychiatric Phenomics (CNP) dataset: We evaluated 130 CNP maps in the reduced ontology benchmark, spanning BART, PAM retrieval, SCAP, and Stop Signal tasks. In this harder transfer setting, NiCLIP achieved task Recall@4 \= 0.77%, concept Recall@4 \= 2.56%, and domain Recall@2 \= 20.77% for the combined vocabulary. With task names alone, performance was task Recall@4 \= 0.77%, concept Recall@4 \= 2.05%, and domain Recall@2 \= 6.54%. The task-only baselines scored 1.54% (GCLDA) and 3.08% (Neurosynth) task Recall@4. Task Switching remains available in the full ontology mapping, but is not part of the reduced Menuet et al. task set and was therefore excluded from the reduced evaluation. + +These cross-dataset results show that NiCLIP generalizes substantially better to IBC than to CNP. We interpret this as evidence that transfer depends strongly on ontology coverage and on how closely the external task paradigms resemble the task-evoked activation structure represented in the training corpus. + +Regarding the Natural Scenes Dataset (NSD), we note that NSD is a visual perception dataset where subjects view naturalistic images. NiCLIP was trained on CBMA data from the fMRI literature, which primarily covers task-based paradigms. The NSD paradigm (passive viewing of natural scenes) is substantially different from the activation patterns in our training data. Applying NiCLIP to NSD would likely reveal the model's capacity to identify visual perception-related concepts but may underperform for scene-specific content, since the training data does not include naturalistic viewing paradigms at scale. We will add this as a discussion point and, if feasible, include a preliminary evaluation. + +In the revision, we will include a new cross-dataset generalization subsection reporting these IBC and CNP results and discussing the boundaries of transfer. + +**Section: Introduction** + **Pages 2-3** + +Using + +## **Reviewer 1 Comment 3** + +Insufficient Credibility Regarding Future Application Prospects + +## **Reviewer 1 Comment 3.1** + +There is an absence of a user-friendly, one-click web interface for researchers. + +### **Response** + +We appreciate this practical suggestion. While building a full web interface is outside the scope of this manuscript, we are committed to making NiCLIP accessible to the community. For the revision, we will: + +1\. Provide a pip-installable Python package (\`braindec\`) with a simple API that allows decoding in three lines of code: + \`\`\`python + from braindec.predict import image\_to\_labels + results \= image\_to\_labels("my\_brain\_map.nii.gz", ...) + \`\`\` + +2\. Include a Jupyter notebook tutorial demonstrating a complete decoding workflow from input brain image to task/concept/domain predictions, which can run on Google Colab. + +3\. Discuss plans for a web interface in the revised manuscript, noting that a web tool in Neurosynth Compose is planned for future development and will be hosted at a publicly accessible URL. + +4\. Release all trained model weights and vocabulary files on the Open Science Framework (OSF), enabling any researcher to run NiCLIP predictions on their own data. + +## **Reviewer 1 Comment 3.2** + +The authors state that "this paper provides researchers with a powerful tool for hypothesis generation and scientific discovery." I believe the author team should empirically demonstrate this claim by exploring or validating a specific hypothesis within the paper, as I currently do not perceive the utility of NiCLIP in hypothesis generation and scientific discovery. + +### **Response** + +This is a fair and constructive point. In the revision, we will include a concrete example of hypothesis generation using NiCLIP. Specifically: + +\*\*Proposed example: Functional characterization of the striatum.\*\* + +In our ROI analysis (Figure 3), NiCLIP predicted a language specialization for the striatum, which is a notable and perhaps surprising finding. We will leverage this prediction as a hypothesis-generation example: + +1\. NiCLIP prediction: The striatum shows strong association with semantic processing tasks (52.1%) and language-related concepts (52.2%), with Language (53.6%) as the dominant domain. + +2\. Literature validation: We will conduct a targeted literature review and meta-analytic verification. The striatum has indeed been implicated in language functions, including lexical-semantic processing (Crosson et al., 2007; Crinion et al., 2006), bilingual language control (Abutalebi & Green, 2016), and word learning (Shohamy & Adcock, 2010). These findings corroborate NiCLIP's prediction. + +3\. Demonstration value: We will present this as a complete cycle: NiCLIP generates a non-obvious prediction → the prediction is validated against existing literature → this demonstrates the model's utility for generating testable hypotheses from brain maps. + +Additionally, we will discuss how NiCLIP could be used to characterize under-studied brain regions or novel parcellations where no prior functional annotation exists, directly serving the hypothesis-generation use case. + +# **Reviewer 2 Comments** + +The paper NCOMMS-25-63915-T entitled "NiCLIP: Neuroimaging contrastive language-image pretraining model for predicting text from brain activation images" presents a framework that leverages existing fMRI literature to learn associations between brain locations and cognitive concepts. This framework uses an ontology to define these concepts and LLMs to extract them from text. The spatial aspect of the model, which involves mapping reported MNI coordinates to brain maps, has been addressed using the PubGet framework. The remaining weak point was extracting the concepts from each individual publication in order to build the associations. The novelty here lies in using LLMs for this task. Furthermore, the concepts are organised within an ontology derived from the Cognitive Atlas. The paper places an important focus on validation, particularly decoding, where images from the Human Connectome Project dataset are decoded into task-specific and contrast-specific concepts. + +I enjoyed reading the paper because it offers a significant opportunity to advance coordinate-based meta-analyses in neuroimaging and make a valuable contribution to improving this field with the use of advanced AI technologies: LLMs and contrastive learning. The effort put into validation is particularly noteworthy, as it comprises both map- and ROI-level experiments. +I would certainly like to see it published in a high-profile venue such as Nature Communications. + +However, there is room for improvement that could be addressed in a revision. + +We sincerely appreciate Reviewer \#2's enthusiasm for the paper and their expert feedback. The reviewer highlights that NiCLIP offers "a significant opportunity to advance coordinate-based meta-analyses in neuroimaging" and notes the "noteworthy" validation effort. Below we address each concern in detail. + +## **Reviewer 2 Comment 1** + +Position with respect to the state of the art + +While the paper cites a wide range of state-of-the-art contributions, it is uneven in its coverage. For example, several papers cited in the discussion (Mensch et al., 2021, 2017; Menuet et al., 2022; Varoquaux et al., 2018\) are omitted from the introduction, despite their conceptual contributions being relevant to NiClip (emphasis on ontology-based analysis and the large-scale use of the Cognitive Atlas and NeuroVault, emphasis on decoding as the most principled validation approach). In my opinion, the contribution of NiClip is clear and it should not be an issue to acknowledge that some of the core intuitions of this paper have been introduced in previous publications, since these did not contribute to any improvement in CBMA (instead, they were about IBMA). The same applies to (Oudyk et al., 2025). + +Moreover, reading the methods section shows that the paper followed the recent contribution by Meudec et al. (2024) quite closely. However, the technical similarity with this prior contribution is not clearly outlined in the introduction. As an online version of the paper is available with some code (https://www.biorxiv.org/content/10.1101/2025.05.23.655707v1.full), I would expect it to perform a formal comparison between the two approaches. + +### **Response** + +We agree and will substantially revise the introduction to provide a more balanced and comprehensive coverage of the state of the art. Specifically: + +\- Mensch et al. (2017, 2021): We will acknowledge their pioneering work on learning neural representations across multiple fMRI studies and their emphasis on supervised decoding as a principled validation approach. Their work demonstrated the value of large-scale aggregation for brain decoding, which is a conceptual foundation of NiCLIP. + +\- Menuet et al. (2022): We will discuss their comprehensive decoding framework using NeuroVault statistical maps and Cognitive Atlas, acknowledging that they introduced (a) the reduced Cognitive Atlas ontology that we also employ, (b) per-term accuracy evaluation for decoding, and (c) the use of image-based meta-analysis (IBMA) for decoding. We will position CBMA vs. IBMA as complementary approaches (see response to point 2 below). + +\- Varoquaux et al. (2018): We will credit their work on atlases of cognition and their emphasis on principled decoding evaluation. + +\- Oudyk et al. (2025): We will incorporate their recent overview of neuroimaging meta-analyses in the introduction. + +\- Meudec et al. (2024): We will add a transparent discussion of the architectural similarities with NeuroConText, acknowledging that NiCLIP follows the CLIP training framework introduced therein, while clearly delineating the novel contributions (Bayesian decoding, domain-specific LLMs, ontology integration). + +The revised introduction will include a paragraph explicitly contrasting IBMA and CBMA approaches: + +"Image-based meta-analysis (IBMA) approaches, which leverage whole-brain statistical maps from repositories like NeuroVault, have demonstrated strong decoding performance by preserving rich spatial information (Mensch et al., 2017, 2021; Menuet et al., 2022; Varoquaux et al., 2018). However, IBMA approaches face limitations in data coverage: despite community efforts, most neuroimaging studies do not share their statistical maps, resulting in sparse and unevenly annotated repositories (Peraza et al., 2025; Salo et al., 2023). In contrast, coordinate-based meta-analysis (CBMA) approaches benefit from the much larger coverage of databases like Neurosynth and BrainMap, which encompass \>30,000 publications. NiCLIP builds on the CLIP framework introduced by NeuroConText (Meudec et al., 2024\) but targets CBMA-based functional decoding, trading the spatial richness of statistical maps for the broader domain coverage of coordinate databases.” +. + +## **Reviewer 2 Comment 2** + +The authors have chosen to frame the decoding part in a Bayesian way, which constrains the interpretation of their results and the type of experiments they can conduct to validate the model. +Ideally, I would like to see a per-term accuracy score for as many terms as possible, as in Menuet et al. (2022). Currently, we don't even have a rough idea of how many concepts, or which ones, can be properly decoded from data or sets of reported locations. + +### **Response** + +We agree this is important for understanding which cognitive terms NiCLIP can reliably decode, and we have now added a per-term analysis with a within-dataset permutation null baseline. + +For the reduced cross-dataset benchmark: + +1\. IBC: In the combined vocabulary setting, 6/11 task terms, 17/30 concepts, and 6/9 domains were above chance. In the names-only setting, 4/11 task terms, 7/30 concepts, and 2/9 domains were above chance. + +2\. CNP: In the combined vocabulary setting, 0/4 task terms and 0/5 domains were above chance, while 1/12 concepts ("recall") exceeded the permutation baseline. In the names-only setting, no task, concept, or domain terms were above chance. + +3\. The above-chance IBC task terms were concentrated in social, emotional, motor, and spatial paradigms rather than being uniformly distributed across the ontology. Examples include emotion processing fMRI task paradigm, emotional localizer fMRI task paradigm, motor fMRI task paradigm, social cognition (theory of mind) fMRI task paradigm, Social localizer fMRI task paradigm, and spatial localizer fMRI task paradigm. + +We will present these per-term results as a heatmap/bar chart in the revision and discuss the term-specific heterogeneity explicitly, rather than relying only on averaged benchmark scores. + +## **Reviewer 2 Comment 2.1** + +The authors have chosen to base their decoding validation on the HCP dataset. I have several issues with this: + +\* First, I did not understand how the decoding was carried out at the task level. There are 22 maps in Collection 457\. How do we go from there to task-level maps? Also, I did not understand how the task-to-concept mapping was done in the method section. Is it simply the Cognitive Atlas, or an improved version of it? Can this be made explicit? + +### **Response** + +We apologize for the lack of clarity. We will revise the Methods section to explicitly describe the mapping procedure: + +Collection 457 in NeuroVault contains 22 contrast maps from 7 HCP task domains. For task-level decoding, we selected one representative contrast per task domain that best captures the core cognitive process: + +Emotion: Faces vs. Shapes +Gambling: Reward vs. Baseline +Language: Story vs. Math +Motor: Average (all movements) +Relational: Relational vs. Match +Social: TOM vs. Random +Working Memory: 2-Back vs. 0-Back + +The task-to-concept mapping relies on the Cognitive Atlas. Each HCP task (e.g., "Emotion processing fMRI task paradigm") has a corresponding entry in the Cognitive Atlas with associated concepts and domains. We used the reduced Cognitive Atlas ontology (derived from Menuet et al., 2022\) for this mapping. We will make this explicit in a revised Methods subsection and provide the full mapping table as supplementary material. + +**Section: Results** + **Page 20** + +Overall, + +## **Reviewer 2 Comment 2.2** + +\* The results of the HCP group-level map decoding are disappointing, whether expressed as decoding probabilities or in Fig. 2\. I believe any expert neuroscientist could easily match group maps from each HCP contrast to their label without error. The failure to do so with NiClip suggests that there is still room for improvement in the framework. In my opinion, the main issue is that the authors extract LLM latents from short task and contrast definitions that are not homogeneous with the long texts used to train the model, resulting in a significant shift in covariates between neuroimaging publications and task/contrast descriptions. This phenomenon is described e.g. in https://www.biorxiv.org/content/10.1101/2025.05.23.655707v1.full. + +### **Response** + +We acknowledge that the HCP group-level maps represent a relatively "easy" benchmark where an expert could likely achieve perfect accuracy by visual inspection. However, we view the HCP evaluation differently: + +1\. The goal is not to outperform human experts. A human expert cannot manually classify a high volume of brain images against a comprehensive cognitive ontology; NiCLIP is designed to provide automated, scalable decoding of any brain map against a structured vocabulary, enabling functional interpretation at a scale that manual classification cannot achieve. The HCP serves as a proof-of-concept validation where ground truth is known. + +2\. NiCLIP's Recall@4 of 62.86% for tasks (Table 2\) should be interpreted in context: the vocabulary contains hundreds of possible tasks, so even Recall@4 means identifying the correct task among the top 4 out of hundreds of candidates. For domains, NiCLIP achieves 90.48% Recall@2. + +3\. The covariate shift issue. We agree with the reviewer that a key limitation is the mismatch between long training texts and short task descriptions used during inference. As described in the expanded NeuroConText preprint (Meudec et al., 2025), there is a distributional shift between LLM embeddings of long publication texts and short task/concept definitions. In NiCLIP, this is partially mitigated by combining the task name and definition embeddings (Eq. in Section 5.3), but we acknowledge that further work on text augmentation strategies (e.g., LLM-based expansion of short definitions to article-like formats, as proposed in Meudec et al., 2025\) could improve performance. We will add this discussion and propose augmentation as a concrete avenue for improvement. + +4\. Improvement over baselines. Despite the acknowledged room for improvement, NiCLIP substantially outperforms both Neurosynth and GC-LDA baselines (\>40% improvement in Recall scores, Table 2), demonstrating that the CLIP-based framework provides meaningful advances for functional decoding. + +## **Reviewer 2 Comment 2.3** + +\* The results on individual HCP data are poor, but this is likely due to the same inadequate matching being exacerbated by the lower SNR and additional variability in the individual data. + +### **Response** + +We agree that subject-level decoding performance is a clear limitation. As discussed in Section 2.4.3, subject-level maps exhibit high variability and noise compared to group-level maps. However, we note that the observed performance (Recall@4 of 38.19% for tasks, Recall@2 of 52.01% for domains) still exceeds chance, suggesting that NiCLIP captures some individual-level signal. + +In the revision, we will: +1\. Conduct an SNR sensitivity analysis by varying group sizes (n=5, 10, 20, 50, 100, all 787 subjects) and reporting decoding accuracy as a function of effective SNR (see also response to the reviewer's related minor point). +2\. Discuss the covariate shift issue as a likely contributing factor, as suggested by the reviewer. +3\. Propose training data augmentation (adding subject-level maps with noise to the training set) as a concrete path to improvement. + +## **Reviewer 2 Comment 2.4** + +\* Finally, I would like the authors to compare their approach with the IBMA decoding of Menuet et al. (2022). It would be extremely interesting to know which decoding approach works best. I actually think that IBMA outperforms... + +### **Response** + +We appreciate the interest in this comparison. However, we note that a direct comparison on the HCP benchmark would not be fair. Menuet et al. (2022) trained their IBMA decoder on NeuroVault statistical maps, which include HCP-derived images in their training set. In contrast, NiCLIP was not trained on HCP data, we use HCP exclusively as a held-out evaluation benchmark. Comparing the two methods on HCP would therefore favor the IBMA decoder due to data leakage. + +That said, a key advantage of NiCLIP's architecture is its scalability. Because NiCLIP is trained on coordinate-based data from \~23,865 publications, it already covers a far broader range of cognitive tasks and domains than what is available in NeuroVault. Moreover, NiCLIP's training set can be easily augmented with image-based data (statistical maps) in the future, combining the breadth of coordinate databases with the richer spatial information of full brain maps. The IBMA approach, by contrast, remains constrained by the limited and inconsistently annotated collection of maps in NeuroVault. + +We will discuss this tradeoff in the revised manuscript, positioning the two approaches as complementary: IBMA provides richer per-map information when high-quality statistical maps are available, while NiCLIP offers broader coverage and a scalable framework that can incorporate both coordinate and image data. + +## **Reviewer 2 Comment 3** + +Additionally, I would like to see validation on datasets other than HCP. Some of the CogAt entries seem to have been designed for HCP in particular, which creates a kind of circularity. Alternative, comprehensive datasets have been shared on NeuroVault, for example. + +### **Response** + +This is an important methodological concern. We address the potential circularity and cross-dataset validation: + +Regarding circularity: The Cognitive Atlas is a community-driven ontology developed independently of HCP. While some HCP tasks have corresponding entries in the Cognitive Atlas, the ontology was not designed specifically for HCP. The reduced Cognitive Atlas vocabulary we use (from Menuet et al., 2022\) contains tasks from diverse sources, not exclusively HCP. Furthermore, NiCLIP's training data consists of \~23,865 PubMed articles spanning the entire fMRI literature, not HCP-specific publications. + +Cross-dataset validation: As noted in our response to Reviewer 1 (point 2(5)), we now include decoding results on: +\- IBC dataset (Individual Brain Charting): 1,608 reduced-ontology maps spanning 11 task families, with task Recall@4 \= 15.67% and domain Recall@2 \= 33.74% in the combined NiCLIP setting +\- CNP dataset (Consortium for Neuropsychiatric Phenomics): 130 reduced-ontology maps spanning BART, PAM retrieval, SCAP, and Stop Signal, with task Recall@4 \= 0.77% and domain Recall@2 \= 20.77% in the combined NiCLIP setting + +We now have runnable decoding pipelines for both datasets (\`jobs/decoding\_ibc.py\`, \`jobs/decoding\_cnp.py\`) and will present these results in a new Results subsection. + + +**Section: Materials and Methods** + **Page 16** + +In Fig. S5, we re + +## **Reviewer 2 Comment 4** + +The authors introduce a reduced version of the Cognitive Atlas, but it is difficult to ascertain how it differs from the original. Providing more details on the motivations and actual differences is important. I am actually puzzled by the necessity to reduce Cognitive Atlas: what do the authors mean? What did they actually do? + +In the discussion, the authors state, 'We demonstrated that a reduced and curated representation of the Cognitive Atlas tasks, combined \[...\] +with a more robust and comprehensive mapping of concepts, outperforms the original Cognitive Atlas ontology." However, this is not clear from their experimental results. + +### **Response** + +We will substantially expand the description of the reduced Cognitive Atlas in the revised Methods. Specifically: + +Motivation: The original Cognitive Atlas contains 851 tasks and 912 concepts, many with incomplete definitions, missing task-concept mappings, or inconsistent annotations. For instance, some popular tasks like "motor fMRI task paradigm" are only linked to a few concepts (e.g., only "working memory"), missing obvious associations (e.g., "movement," "motor control"). This incompleteness propagates to the concept and domain predictions. + +What was done: Following Menuet et al. (2022), we used a curated subset that: +\- Retains \~100 of the most commonly used fMRI tasks based on their prevalence in the literature. +\- Manually enriches task-to-concept mappings to ensure that each task is associated with all relevant cognitive concepts. For example, the motor fMRI task paradigm is linked to "movement," "motor control," "motor learning," and other motor-related concepts in the reduced version. +\- Adds manual concept-to-domain mappings for concepts missing domain annotations (16 manual mappings, listed in our code at \`braindec/cogatlas.py\`, Lines 26–43). +\- Maintains the original concept-to-domain structure using the 10 cognitive process categories from Cognitive Atlas. + +Evidence of improvement: Table 2 demonstrates that the reduced ontology consistently outperforms the original across all models and metrics. For example, with BrainGPT-7B-v0.2 (body, name+definition), Recall@4 for tasks increases from near-zero with the full Cognitive Atlas to 62.86% with the reduced version. + +We will add a supplementary table comparing the full vs. reduced Cognitive Atlas (number of tasks, concepts, task-concept edges, concept-domain edges) and provide the complete mapping file as supplementary material. + +**Section: Abstract** + **Page 1** + +Finally,. + +## **Reviewer 2 Comment 5** + +## Related to point 2, I think the authors should provide visualisations of the geometry of the embeddings so that their structure can be checked.. + +### **Response** + +We agree that embedding visualization would strengthen the paper. In the revision, we will add: + +1\. UMAP/t-SNE visualization of the shared latent space showing both text embeddings (Cognitive Atlas tasks, colored by domain) and image embeddings (HCP maps, colored by task), demonstrating the alignment quality in the learned space. + +2\. Embedding structure analysis showing: + \- Clustering of task embeddings by cognitive domain + \- Distance between HCP image embeddings and their corresponding task text embeddings + \- Whether semantically related tasks cluster together (e.g., motor tasks, language tasks) + +3\. Comparison of embedding geometries across different LLMs (BrainGPT vs. Mistral vs. Llama), providing insight into why BrainGPT achieves better performance. + +This will be presented as a new supplementary figure. + +**Section: Introduction** + **Page 2** + +The main goal of parcellating functional connectomes is to reduce high-dimensional connectivity space into a. + +## **Reviewer 2 Comment 6** + +## Introduction: I suggest discussing the relative merits of IBMA and CBMA more explicitly. + +### **Response** + +Agreed. We will add a dedicated paragraph in the Introduction comparing IBMA and CBMA for functional decoding (see revised text in response to Major Point 1 above). + +## **Reviewer 2 Comment 7** + +## In the abstract, the authors emphasize the difference between BrainGPT and other LLMs, but the effect is actually modest, as can be seen in Table 1\. + +### **Response** + +We acknowledge this. The differences are statistically modest (e.g., Recall@10: 33.56 vs. 33.36 for BrainGPT-7B-v0.2 vs. Mistral). We will revise the abstract to temper this claim: + +\> "We demonstrated that domain-specific fine-tuned LLMs (e.g., BrainGPT) provide modestly improved text-to-brain associations compared to their base counterparts, with more pronounced benefits observed in downstream decoding tasks." + +We note that the BrainGPT advantage becomes more pronounced in the decoding evaluation (Table 2), where the difference between BrainGPT and base LLMs is larger (e.g., Task Recall@4: 62.86% vs. 55% for BrainGPT-v0.2 vs. v0.1). + +## **Reviewer 2 Comment 8** + +## Why didn't the authors use the DiFuMo1024 dictionary instead of the DiFuMo512 dictionary? + +### **Response** + +We chose DiFuMo512 following the NeuroConText framework (Meudec et al., 2024, which also used DiFuMo512) to ensure a fair comparison and because it provides a balance between spatial resolution and model complexity. DiFuMo1024 doubles the image embedding dimension, which would increase training time and may require architectural adjustments. The NeuroConText ablation studies (Meudec et al., 2025\) showed that DiFuMo512 provided a meaningful improvement over DiFuMo256, but we are not aware of published evidence that DiFuMo1024 provides substantial further improvement for text-brain association tasks. We will add this justification to the Methods and note DiFuMo1024 as a future direction. + +## **Reviewer 2 Comment 9** + +## Could the authors provide more comprehensive captions for the tables in the paper? Currently, it is difficult to identify what the numbers in the tables actually mean. + +### **Response** + +Agreed. We will expand all table captions to include: +\- Clear definition of each metric +\- Description of what rows and columns represent +\- Explanation of how to interpret the numbers (e.g., "higher is better") +\- Reference to the specific sections where the experimental setup is described + +## **Reviewer 2 Comment 10** + +## Regarding Fig. 2, I don't understand how the authors end up with one map per task since tasks generally have more than one possible contrast and sometimes several contrasts of interest (e.g., MOTOR, WM). + +### **Response** + +See our response to Major Point 2 above regarding the task-level mapping. We selected one representative contrast per task domain. We will make this explicit in the figure caption and Methods section, and provide the full mapping in supplementary material. + +## **Reviewer 2 Comment 11** + +## I am quite surprised by the finding that the striatum shows a specialization for language. Image decoding results are quite underwhelming overall. + +### **Response** + +We share the reviewer's initial surprise, but note that the striatum's role in language has been documented in the literature. The caudate nucleus and putamen have been implicated in: +\- Lexical-semantic processing (Crosson et al., 2007\) +\- Bilingual language control (Abutalebi & Green, 2016; Crinion et al., 2006\) +\- Speech production (Bohland & Guenther, 2006\) +\- Syntactic processing (Ullman, 2004\) + +The striatum meta-analytic parcellation we used (Liu et al., 2020\) encompasses both reward-related and language-related subregions, which likely contributes to this prediction. We will add this discussion to the revised Results section and use it as the hypothesis-generation example (see response to Reviewer 1, point 3(2)). + +## **Reviewer 2 Comment 12** + +## The statement "The current trained model should not be used to decode images with high noise, such as subject-level activation maps, as our decoding model performs poorly on this type of data" could be clarified by varying the signal-to-noise ratio (SNR) in the data and taking group maps from different sample sizes. Nevertheless, as discussed above, the main issue is probably not the noise in the images but rather the difficulty of obtaining good embeddings for contrasts or tasks text descriptions. + +### **Response** + +We agree this analysis would be informative. In the revision, we will: +\- Compute group-average maps from subsets of N \= {5, 10, 20, 50, 100, 200, 787} subjects +\- Report decoding accuracy as a function of group size +\- Present this as a figure showing the relationship between effective SNR (or sample size) and decoding performance +\- Discuss whether the primary limitation is SNR or the covariate shift issue + +## **Reviewer 2 Comment 13** + +## " If one is interested in decoding both ends of the activation distribution in an image separately, one could flip the sign of the image to force the decoder to predict the negative tail." I'm not sure I understand the use case. + +### **Response** + +We will clarify this use case in the revision: + +"In task fMRI, statistical maps often contain both positively activated regions (task \> baseline) and negatively activated regions (baseline \> task). NiCLIP was trained primarily on positively activated coordinates. To decode the functional significance of deactivated regions, a user could invert the sign of the map so that deactivations become positive, then apply NiCLIP to characterize the cognitive processes associated with those deactivated areas. For example, decoding the inverted default mode network deactivation during a working memory task might reveal associations with resting-state or self-referential processing concepts." + +## **Reviewer 2 Comment 14** + +## Discussion: What does "continuous decoding" mean ? + +### **Response** + +We will replace "continuous decoding" with more precise language: + +"NiCLIP accurately performs functional decoding on whole-brain statistical maps (dense activation maps covering the full brain volume) as well as sparse brain images, such as regions of interest." + +## **Reviewer 2 Comment 15** + +## In the methods section, the authors define a "likelihood" P(A\_k|T) \= softmax(Emb(T).Emb(Ak)) + +## This is a heuristic; besides the fact that the result is a number between 0 and 1, it is unclear how it can be interpreted as a proper distribution because it is not calibrated as a proper probability. + +### **Response** + +The reviewer is correct. The softmax of cosine similarities produces a distribution that sums to 1, but this is not a calibrated probability in the statistical sense. It is a scoring function normalized to produce a probability-like output. We will revise the Methods to clarify this: + +"We note that P(A\_k|T) as defined by the softmax of cosine similarities in the CLIP latent space is a heuristic likelihood that quantifies the relative compatibility between activation patterns and task descriptions. While this normalized score shares properties with a probability distribution (non-negative, sums to 1), it is not calibrated as a proper statistical likelihood. Similarly, the posterior P(T|A) should be interpreted as a relative ranking score rather than a rigorously calibrated probability. This is consistent with the common use of softmax-normalized similarity scores in contrastive learning frameworks (Radford et al., 2021)." + +## **Reviewer 2 Comment 16** + +## The formula for the probability of a concept (C\_j ) given an activation A\_k seems to assume some independence of the P(T\_i |A\_k) probabilities, which likely does not hold. Therefore, these "probabilities" are, at best, a proxy and are not rigorous. This must be made clear. + +### **Response** + +The reviewer is correct that the noisy-OR model assumes conditional independence of the task probabilities given a concept, which is an approximation. We will add the following clarification: + +"The noisy-OR model used to compute P(C\_j|A\_k) assumes conditional independence among tasks that measure the same concept. This assumption is a simplification, as tasks sharing a concept may have correlated activation patterns. The resulting concept and domain probabilities should therefore be interpreted as approximate scores reflecting the aggregate evidence from related tasks, rather than as rigorously derived posterior probabilities. Despite this approximation, the noisy-OR model provides a principled mechanism for propagating task-level predictions through the ontological hierarchy, and its effectiveness is empirically validated by the meaningful concept and domain predictions observed in our results." + +# **Reviewer 3 Comments** + +Peraza and colleagues present NiCLIP, a CLIP-based neural network model, that learns to decode cognitive terms from coordinate-based brain activation maps. The model outputs predictions based on task, concept, and domain labels from the Cognitive Atlas. This provides a qualitative advance over previous reverse-inference models (like NeuroSynth) by learning a nonlinear mapping between brain images and words, taking advantage of the rich contextual-semantic representations encoded by large language models (LLMs). This manuscript is very timely and already in pretty good shape. The methodology seems solid and I believe the results. That said, I found certain bits of the text were difficult to follow; for example, I had a hard time understanding the distinguishing features of the models, which comparisons are most important, and which elements of NiCLIP are driving improved performance. Most of the following comments are clarification questions or suggestions that I think the authors can readily address. + +### **Response** + +We thank Reviewer \#3 for their positive assessment ("very timely," "methodology seems solid," "cool paper\!") and their constructive suggestions for improving clarity. We address each point below. + +## **Reviewer 3 Comment 1** + +In my first read through the manuscript, I had some difficulty following the narrative. I found myself asking questions like “wait, what exactly are the differences between the NiCLIP model and the CLIP model they were discussing in the previous section?” and “what exactly is this model trained and tested on? and is that different from the previous model?” Is the distinguishing feature of the “CLIP” model that it’s text-to-brain, whereas the distinguishing feature of the “NiCLIP” model is that it’s brain-to-text (and also trained on CogAtlas)? Couldn’t you theoretically also decode caption-style text directly from the CLIP model trained on brain images, without the CogAtlas? A good deal of this becomes somewhat clearer upon reading the Methods (at the end), but I think readers would benefit from a little more hand-holding throughout the Results. To be clear, I don’t think this is done poorly at all even in the current version—it’s just that this whole methodology is a complex beast, and very few readers will be familiar enough with all the different components to fully “get it” on the first read. Maybe introducing each section with a question or motivation sentence would help. + +### **Response** + +We appreciate this feedback about readability. In the revision, we will: + +1\. Add a "motivation sentence" at the beginning of each Results subsection, framing the question being addressed. For example: + \- Section 2.2: "Before evaluating functional decoding, we first assessed whether the CLIP framework can learn meaningful text-to-brain associations from the neuroimaging literature." + \- Section 2.3: "Given a trained text-to-brain CLIP model, can we perform functional decoding, predicting cognitive tasks from brain activation maps?" + +2\. Provide a clear conceptual distinction early in the Results: + \- CLIP model: The contrastive learning framework that aligns article text embeddings with brain activation map embeddings in a shared latent space. Trained on 23,865 PubMed article-coordinate pairs. Evaluated with Recall@K and Mix\&Match on held-out articles. + \- NiCLIP model: The decoding framework that uses the trained CLIP model for reverse inference. Instead of retrieving articles, it computes posterior probabilities for cognitive tasks from a structured ontology (Cognitive Atlas) by comparing their name/definition embeddings to a new brain map embedding in the CLIP latent space. + +3\. Add a "Couldn't you also decode caption-style text directly from the CLIP model?" paragraph explaining that yes, one could do nearest-neighbor text retrieval (as NeuroConText does), but NiCLIP's ontology-based approach provides structured, interpretable outputs at the task/concept/domain levels, which are more useful for hypothesis generation.. + +## **Reviewer 3 Comment 2** + +Following on the bit from the previous comment about training/testing, should readers be worried about potential leakage between the PubMed Central data used for training and HCP data used for testing the models? Isn’t it possible that some of the PMC training articles are reporting coordinates derived from exactly the same HCP data you use to test the model? I assume the authors ensure these are non-overlapping somehow (or maybe I just don’t fully understand the structure of the data), but I think this could be made more explicit. Related thought: I assume the articles are effectively randomized regarding topic, so that you don’t end up holding out a large chunk of articles on a single topic for a particular test set? + +### **Response** + +This is an important methodological concern. We want to clarify: + +1\. No coordinate-level leakage. NiCLIP is trained on article text paired with coordinates reported in those articles. Even if some training articles report HCP-derived results, the model does not see the HCP statistical maps during training — it only sees MKDA-modeled coordinate maps, which are a lossy representation of the original data. + +2\. Text-level potential overlap. Some PubMed articles may describe HCP task contrasts. However, these articles would describe many different analyses and findings from HCP, not just the specific contrasts we test on. The CLIP model learns general text-brain associations, not memorized mappings from specific contrasts. + +3\. Structural safeguard. The evaluation uses group-level t-statistic maps from NeuroVault (Collection 457), which are fundamentally different from the MKDA-modeled coordinate maps used in training. This structural mismatch between training (modeled coordinates) and evaluation (statistical maps) data makes direct leakage unlikely. + +4\. Topic randomization. Our 23-fold cross-validation splits articles randomly, so no systematic topic exclusion occurs. Articles are distributed across folds without topic-based stratification. + +We will add a paragraph in the Methods discussing potential leakage and these mitigating factors: + +\> \*"We note that the PubMed training corpus may include articles analyzing HCP data. However, several factors mitigate potential leakage: (1) the model is trained on MKDA-modeled coordinate maps, which are structurally distinct from the group-level t-statistic maps used for evaluation; (2) even if an article discusses an HCP contrast, the text embedding captures the full article content, not a one-to-one mapping to a specific statistical map; and (3) cross-validation folds are constructed by randomly sampling articles without topic stratification, ensuring no systematic bias toward HCP-related content."\* +. + +**Section: Materials and Methods** + **Pages 6-7** + +Importantly, . + +## **Reviewer 3 Comment 3** + +I had some difficulty understanding how exactly the language component of the model is encoding the text. For example, in my own work with LLMs, we’re often using the time series of word-by-word embeddings to capture the meaning of text. Does a single embedding for an entire article comprising thousands of words really capture all the nuances of meaning (the “deep semantic relationships” the authors advertise in the introduction) in that article? I could understand how a whole trajectory of word-by-word embeddings could capture the narrative of an article in a fairly rich, context-sensitive way—but wouldn’t you lose a good bit of this meaning and structure in collapsing the article into a single embedding? + +### **Response** + +This is an astute question about representation capacity. We will add a discussion: + +1\. How embeddings are computed: For long articles, we chunk the text into segments within the LLM's context window, compute an embedding for each chunk, and then average across chunks (Section 5.1.1). This mean-pooling approach does lose sequential/narrative structure but preserves the aggregate semantic content. + +2\. Why this works for our application: We are not trying to capture the narrative arc of an article. Rather, we need an embedding that represents \*what cognitive topics and brain regions the article discusses\*. Mean-pooled LLM embeddings have been shown to effectively capture document-level topics and semantic content in information retrieval tasks. + +3\. Empirical evidence: Our Table 1 results show that full-text (body) embeddings substantially outperform abstract-only embeddings (Recall@10: 33.56 vs. 24.01), indicating that the additional text does provide meaningful discriminative information beyond what's in the abstract. + +4\. Acknowledged limitation: We agree that richer text representations (e.g., attention-weighted pooling, multi-vector embeddings) could capture more nuanced semantic content. We will note this as a future direction. +. + +**Section: Discussion** + **Page 34** + +We recog. + +## **Reviewer 3 Comment 4** + +Table 2 is very dense. Can you hold the reader’s hand a bit more as to which numbers we should be comparing? For example, am I correct in understanding that GC-LDA task Recall@4 (17.14) outperforms NiCLIP task Recall@4 (10.71)? Isn’t this comparison a bit surprising? + +### **Response** + +We will substantially revise Table 2 for clarity: + +1\. Add a more detailed caption explaining what each row, column, and number represents. +2\. Highlight key comparisons (e.g., bold the best NiCLIP configuration, add asterisks for significant improvements over baselines). +3\. Add a summary row showing the best NiCLIP configuration vs. baselines. + +Regarding the specific comparison: If the reviewer is comparing GC-LDA task Recall@4 (17.14) with a specific NiCLIP configuration that shows 10.71, this is likely a configuration using the \*\*full (uncurated) Cognitive Atlas\*\* ontology, which performs poorly due to incomplete mappings. The best NiCLIP configuration (body \+ BrainGPT-7B-v0.2 \+ reduced CogAt \+ name+definition) achieves \*\*62.86% task Recall@4\*\*, which substantially outperforms GC-LDA's 17.14%. The key message is that ontology choice dramatically affects performance. We will add a note in the text explicitly guiding readers to the most important comparisons. + +. + +**Section: Results** + **Page 20** + +Overal + +## **Reviewer 3 Comment 5** + +The authors mention bag-of-words methods like TF-IDF in the Introduction, suggesting that LLMs will improve on this method. This set me up to expect a comparison to TF-IDF—but then I didn’t see it directly mentioned. Is the NeuroSynth baseline model effectively using TF-IDF? + +### **Response** + +Yes, the Neurosynth correlation decoder effectively uses TF-IDF. Neurosynth extracts term frequencies from article abstracts using automated text mining and creates meta-analytic maps by aggregating coordinates associated with each term — this is functionally equivalent to a TF-IDF-based approach. The GC-LDA baseline also uses term counts (a simpler form of bag-of-words). + +We will make this connection explicit in the Results: + +"Both baseline models rely on bag-of-words text representations: the Neurosynth correlation decoder uses automated term extraction akin to TF-IDF (Yarkoni et al., 2011), while GC-LDA uses raw term counts (Rubin et al., 2017). NiCLIP's substantial improvement over these baselines demonstrates the benefit of replacing bag-of-words representations with LLM-derived contextual embeddings." + +**Section: Limitations** + **Page 34-35** + +Second, the. + +## **Reviewer 3 Comment 6** + +Can you say a little bit more about the metrics (e.g., Recall@k, Mix\&Match at 2.2) when you first introduce them, without having to refer to the Methods? For example, in 2.3, you say, “In decoding, Recall@k represents…”—a nice, concise definition like this would also be useful earlier on. + +### **Response** + +Agreed. We will add concise metric definitions at first use in the Results: + +"We assessed the CLIP model using Recall@K and Mix\&Match (see Methods for formal definitions). Recall@K measures the percentage of test samples where the true text-image match appears among the top K ranked candidates — higher values indicate better retrieval. Mix\&Match assesses whether each brain map is more similar to its true corresponding text than to other texts in the set, serving as a discriminability measure." + +## **Reviewer 3 Comment 7** + +On page 9, you refer to “The reduced and enhanced versions of Cognitive Atlas.” What does “enhanced” mean here? Are you referring to two different versions (1 \= reduced, 2 \= enhanced) of the CogAtlas, or do you just mean the “reduced” version is also “enhanced”? + +### **Response** + +We apologize for the confusing wording. "Reduced" and "enhanced" refer to the same version: a Cognitive Atlas that is \*reduced\* in the number of tasks (retaining only popular tasks) and \*enhanced\* in its task-to-concept mappings (with enriched, manually curated connections). We will revise to: + +"The curated version of Cognitive Atlas, which retains the most commonly used fMRI tasks while enriching their concept mappings, consistently outperformed the original Cognitive Atlas ontology across all models." +. + +## **Reviewer 3 Comment 8** + +On page 9, you say “The predictions of domains consistently showed higher recall rates than tasks and concepts across all models and configurations…” Could this difference just be due to differences in the structure of these target variables? For example, maybe domains has fewer distinct elements than tasks, so it makes for an easier decoding task? Some kind of shuffled/permuted/null baseline could provide a useful point of comparison here. + +### **Response** + +This is correct. Domain prediction is an easier classification problem: +\- Domains: 10 categories +\- Tasks: \~400 categories (reduced CogAt) +\- Concepts: \~600 categories + +The higher domain Recall@2 therefore reflects both model informativeness and the smaller candidate set. To address this directly, we now compute permutation-based null baselines within each dataset/configuration rather than relying only on structural chance. + +For example, in the reduced cross-dataset benchmark: +1\. IBC combined: mean per-term hit@k was 0.201 for tasks versus a null mean of 0.101, 0.114 for concepts versus 0.055, and 0.284 for domains versus 0.221. +2\. CNP combined: mean per-term hit@k was 0.010 for tasks versus 0.0017, 0.0239 for concepts versus 0.0122, and 0.247 for domains versus 0.297. +3\. We also report normalized accuracy and the number of terms above chance. On IBC, 6/11 task terms, 17/30 concepts, and 6/9 domains were above chance in the combined setting, whereas on CNP only 1/12 concepts exceeded the null baseline. + +We will add a note in the manuscript explaining the structural difference between label levels and report both raw and chance-normalized scores. + +## **Reviewer 3 Comment 9** + +For the impact of this paper, I think it’s important to ask: How can people actually use this model? Can you provide a little more logistical details (or a recipe) for how others might use the trained model and code for their own research? (I see there are some pointers on the GitHub repo… Are authors planning to build a Neurosynth-style website for this?) + +### **Response** + +We will add a "Practical Usage" subsection in the Discussion describing: + +1\. Code and model availability: All code is available at https://github.com/NBCLab/brain-decoder. Trained model weights and vocabulary files will be released on OSF. + +2\. Quick-start recipe: + \`\`\`python + pip install braindec + from braindec.predict import image\_to\_labels + results \= image\_to\_labels( + "path/to/brain\_map.nii.gz", + model\_path="path/to/trained\_model.pth", + vocabulary=vocabulary, + vocabulary\_emb=vocabulary\_emb, + vocabulary\_prior=vocabulary\_prior, + ) + \`\`\` + +3\. Google Colab notebook demonstrating the complete workflow. + +4\. Future plans for a NeuroSynth Compose interface. + +## **Reviewer 3 Comment 10** + +Page 3: “Finally, we examined the extent to which NiCLIP’s capability in predicting subject-level activation maps.”—seems like there’s a word missing here. + +### **Response** + +Corrected to: \*"Finally, we examined the extent of NiCLIP's capability in predicting subject-level activation maps."\* + +## **Reviewer 3 Comment 11** + +5.1.2: “and may not always be factual”—I’m not sure what “factual” would mean in this context, but I get your point; maybe “widely agreed upon” is better? + +### **Response** + +Revised to: \*"As a community-based ontology, these mappings reflect the opinions of individual researchers and may not always be widely agreed upon or empirically validated."\* + +## **Reviewer 3 Comment 12** + +5.2: “The text encoder is characterized by a projection head and two residual heads, while the image encoder comprises three residual heads.” Is “head” the typical terminology here? I’m not super familiar with CLIP, but my brain is getting interference with the use of “head” in describing the attention heads in each layer? + +### **Response** + +We understand the potential confusion with attention heads. In our architecture, "head" refers to modular network blocks (fully connected layers with activation, dropout, and normalization), not attention heads. We will revise the terminology to avoid confusion: + +\> \*"The text encoder consists of a projection block and two residual blocks, while the image encoder comprises three residual blocks. Each block contains a fully connected layer, GELU activation, dropout, and layer normalization. We use 'block' rather than 'head' to distinguish these architectural components from the attention heads in transformer models."\* diff --git a/tests/test_fetcher.py b/tests/test_fetcher.py new file mode 100644 index 0000000..792a324 --- /dev/null +++ b/tests/test_fetcher.py @@ -0,0 +1,117 @@ +import importlib.util +import os.path as op +import sys +import types +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +FETCHER_PATH = ROOT / "braindec" / "fetcher.py" + + +def _load_fetcher_module(monkeypatch, tmp_path): + braindec_pkg = types.ModuleType("braindec") + utils_mod = types.ModuleType("braindec.utils") + utils_mod.get_data_dir = lambda data_dir=None: str(tmp_path if data_dir is None else Path(data_dir)) + + monkeypatch.setitem(sys.modules, "braindec", braindec_pkg) + monkeypatch.setitem(sys.modules, "braindec.utils", utils_mod) + + spec = importlib.util.spec_from_file_location("fetcher_under_test", FETCHER_PATH) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_materialized_path_to_local_path(monkeypatch, tmp_path): + fetcher = _load_fetcher_module(monkeypatch, tmp_path) + destination = fetcher._materialized_path_to_local_path( + "/results/pubmed/model.pth", + tmp_path / "downloads", + ) + assert destination == tmp_path / "downloads" / "results" / "pubmed" / "model.pth" + + +def test_download_bundle_dispatches_all_assets(monkeypatch, tmp_path): + fetcher = _load_fetcher_module(monkeypatch, tmp_path) + calls = [] + + def fake_download_asset(name, destination_root=".", overwrite=False, node_id=None, timeout=60): + calls.append((name, Path(destination_root), overwrite, node_id, timeout)) + return [Path(destination_root) / name] + + monkeypatch.setattr(fetcher, "download_asset", fake_download_asset) + + downloaded = fetcher.download_bundle("example_prediction", destination_root=tmp_path, overwrite=True) + expected_assets = fetcher.OSF_BUNDLES["example_prediction"] + assert [call[0] for call in calls] == expected_assets + assert downloaded == [tmp_path / asset for asset in expected_assets] + + +def test_download_osf_folder_recurses_into_nested_folders(monkeypatch, tmp_path): + fetcher = _load_fetcher_module(monkeypatch, tmp_path) + + def fake_get_folder_item(node_id, remote_path, provider=None, timeout=60): + assert remote_path == "results/pubmed" + return {"id": "root-folder", "materialized_path": "/results/pubmed/"} + + def fake_iter_children(node_id=None, folder_id=None, provider=None, timeout=60): + if folder_id == "root-folder": + return iter( + [ + { + "id": "nested-folder", + "attributes": { + "kind": "folder", + "name": "subdir", + "materialized_path": "/results/pubmed/subdir/", + }, + "links": {}, + }, + { + "id": "file-a", + "attributes": { + "kind": "file", + "name": "model-a.pth", + "materialized_path": "/results/pubmed/model-a.pth", + }, + "links": {"download": "https://example.test/model-a"}, + }, + ] + ) + if folder_id == "nested-folder": + return iter( + [ + { + "id": "file-b", + "attributes": { + "kind": "file", + "name": "model-b.pth", + "materialized_path": "/results/pubmed/subdir/model-b.pth", + }, + "links": {"download": "https://example.test/model-b"}, + } + ] + ) + raise AssertionError(f"Unexpected folder_id {folder_id}") + + downloaded = [] + + def fake_download_to_file(url, destination, overwrite=False, timeout=60, chunk_size=None): + downloaded.append((url, Path(destination), overwrite)) + return Path(destination) + + monkeypatch.setattr(fetcher, "_get_folder_item", fake_get_folder_item) + monkeypatch.setattr(fetcher, "_iter_children", fake_iter_children) + monkeypatch.setattr(fetcher, "_download_to_file", fake_download_to_file) + + results = fetcher.download_osf_folder("results/pubmed", destination_root=tmp_path, overwrite=True) + + assert results == [ + tmp_path / "results" / "pubmed" / "model-a.pth", + tmp_path / "results" / "pubmed" / "subdir" / "model-b.pth", + ] + assert downloaded == [ + ("https://example.test/model-a", tmp_path / "results" / "pubmed" / "model-a.pth", True), + ("https://example.test/model-b", tmp_path / "results" / "pubmed" / "subdir" / "model-b.pth", True), + ] diff --git a/tests/test_jobs_utils.py b/tests/test_jobs_utils.py new file mode 100644 index 0000000..a56c41a --- /dev/null +++ b/tests/test_jobs_utils.py @@ -0,0 +1,50 @@ +import argparse +import os.path as op +import sys + +ROOT = op.abspath(op.join(op.dirname(__file__), "..")) +if ROOT not in sys.path: + sys.path.insert(0, ROOT) + +from jobs import utils + + +def test_str_to_bool_accepts_common_true_false_values(): + assert utils.str_to_bool("true") is True + assert utils.str_to_bool("YES") is True + assert utils.str_to_bool("0") is False + assert utils.str_to_bool("No") is False + + +def test_str_to_bool_rejects_unknown_value(): + try: + utils.str_to_bool("maybe") + except argparse.ArgumentTypeError: + return + + raise AssertionError("Expected argparse.ArgumentTypeError for unknown boolean input.") + + +def test_get_default_project_dir_points_to_repo_root(): + project_dir = utils.get_default_project_dir() + assert op.basename(project_dir) == "brain-decoder" + assert op.isdir(op.join(project_dir, "jobs")) + assert op.isfile(op.join(project_dir, "pyproject.toml")) + + +def test_resolve_project_paths_defaults_to_repo_layout(): + project_dir, data_dir, results_dir = utils.resolve_project_paths() + assert data_dir == op.join(project_dir, "data") + assert results_dir == op.join(project_dir, "results") + + +def test_add_common_job_args_uses_expected_defaults(): + parser = utils.add_common_job_args(argparse.ArgumentParser()) + args = parser.parse_args([]) + + assert args.project_dir == utils.get_default_project_dir() + assert args.sections == utils.DEFAULT_SECTIONS + assert args.model_ids == utils.DEFAULT_MODEL_IDS + assert args.reduced is True + assert args.topk == 20 + assert args.standardize is False diff --git a/tests/test_predict.py b/tests/test_predict.py new file mode 100644 index 0000000..22e9b2c --- /dev/null +++ b/tests/test_predict.py @@ -0,0 +1,57 @@ +import numpy as np +import pandas as pd +import torch + +from braindec import predict + + +class _FakeModel: + logit_scale = torch.tensor(1.0) + + def __call__(self, image_input, text_inputs): + return image_input, text_inputs + + +def _call_image_to_labels(return_posterior_probability=False): + vocabulary = ["motor", "language"] + vocabulary_emb = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) + prior = np.array([0.5, 0.5], dtype=np.float32) + + return predict.image_to_labels( + image=object(), + model_path="unused", + vocabulary=vocabulary, + vocabulary_emb=vocabulary_emb, + prior_probability=prior, + model=_FakeModel(), + topk=2, + return_posterior_probability=return_posterior_probability, + ) + + +def test_image_to_labels_returns_dataframe_by_default(monkeypatch): + monkeypatch.setattr(predict, "_get_device", lambda: torch.device("cpu")) + monkeypatch.setattr( + predict, + "preprocess_image", + lambda image, **kwargs: torch.tensor([[1.0, 0.0]], dtype=torch.float32), + ) + + result = _call_image_to_labels() + + assert isinstance(result, pd.DataFrame) + assert result.iloc[0]["pred"] == "motor" + + +def test_image_to_labels_can_return_posterior_probability(monkeypatch): + monkeypatch.setattr(predict, "_get_device", lambda: torch.device("cpu")) + monkeypatch.setattr( + predict, + "preprocess_image", + lambda image, **kwargs: torch.tensor([[1.0, 0.0]], dtype=torch.float32), + ) + + task_df, posterior_probability = _call_image_to_labels(return_posterior_probability=True) + + assert isinstance(task_df, pd.DataFrame) + assert isinstance(posterior_probability, torch.Tensor)