Research codebase for training and analyzing Multi-Plastic Networks (MPNs) on a battery of cognitive tasks. The core idea is that a single network with Hebbian-like synaptic plasticity can learn to solve many tasks simultaneously, and the structure of its plastic weights can be analyzed to understand how task-specific computation is organized.
The central model is DeepMultiPlasticNet (mpn.py), a recurrent network with one MultiPlasticLayer (mp_layer1) whose effective weights are modulated by a fast-timescale plasticity matrix M:
W_eff(t) = W + W ⊙ M(t)    (multiplicative)
W_eff(t) = W + M(t)        (additive)
M evolves by a Hebbian-like rule with learnable parameters:
| Parameter | Symbol | Description |
|---|---|---|
| Learning rate | η (eta) | Scales the Hebbian update |
| Decay | λ (lam) | Controls timescale of synaptic memory |
Both η and λ can be scalar, pre-vector, post-vector, or full matrix.
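Concretely, the update can be sketched in NumPy as below. This is a minimal illustration of the multiplicative/additive forms above, not the exact rule in `mpn.py`; `plastic_step` is a hypothetical name, and broadcasting covers the scalar, pre-vector, post-vector, and full-matrix cases for η and λ.

```python
import numpy as np

def plastic_step(W, M, pre, post, eta, lam, multiplicative=True):
    # Hebbian-like trace: decay the old M at rate lam, accumulate the
    # outer product of post- and pre-synaptic activity scaled by eta.
    # eta/lam may be scalars, per-pre vectors, per-post vectors, or
    # full matrices -- NumPy broadcasting handles all four shapes.
    M = lam * M + eta * np.outer(post, pre)
    # Effective weights, as in the two forms above.
    W_eff = W + W * M if multiplicative else W + M
    return W_eff, M
```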
The full network (DeepMultiPlasticNet) has three weight matrices:
- `W_initial_linear` — input projection (pre-synaptic neurons)
- `mp_layer1.W` — recurrent plastic weights (hidden neurons)
- `W_output` — readout
Training → Analysis → Clustering → Lesion / Pruning
`python multiple_task.py` — trains the MPN on a set of cognitive tasks defined in `mpn_tasks.py`. Saves:

- `multiple_tasks/savednet_{aname}.pt` — model checkpoint
- `multiple_tasks/param_{aname}_param.json` — hyperparameters
- `multiple_tasks/param_{aname}_result.npz` — training curves
Key hyperparameters (set inside the script):
- `hidden` — number of recurrent units
- `batch` — batch size
- `seed` — random seed
- `feature` — regularization config (e.g. `L21e4`)
`python multiple_task_analysis.py` — loads a trained model, evaluates it on all tasks, and produces:
- Task-conditioned activity matrices
- Cluster analysis of input and hidden neurons
- Low-dimensional (PCA) trajectory plots
- Saves `cluster_info_{aname}_normalized.pkl` for downstream use
`python clustering.py` — implements hierarchical clustering (`clustering_metric.py`) with silhouette-score-based automatic selection of the number of clusters k. Clusters neurons by their task-tuning profiles.
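The automatic k selection can be sketched with scikit-learn as follows. This is an illustrative stand-in for the actual logic in `clustering.py`/`clustering_metric.py`; `pick_k` and the k range are assumptions.

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score

def pick_k(profiles, k_range=range(2, 10)):
    # profiles: (n_neurons, n_tasks) task-tuning matrix.
    # Cut the hierarchy at each candidate k and keep the k whose
    # cluster labels maximize the silhouette score.
    best_k, best_s = None, -1.0
    for k in k_range:
        labels = AgglomerativeClustering(n_clusters=k).fit_predict(profiles)
        s = silhouette_score(profiles, labels)
        if s > best_s:
            best_k, best_s = k, s
    return best_k
```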
`python leison.py` — given a trained model and its cluster assignments, runs lesion experiments using a fixed number of clusters (`FIXED_K`, inferred from the upstream `multiple_task_analysis.py` pickle). The same dendrogram is cut at this fixed k for input, hidden, and modulation clusters, ensuring consistent granularity across all analyses.
Single-cluster lesion (input & hidden): For each neuron cluster (both normalized and unnormalized variants), zeros out all connections to/from that cluster and measures per-task accuracy. Input ("pre") and hidden ("post") clusters are each lesioned independently in leave-one-out fashion.
Random lesion: For each cluster lesion condition, lesions a size-matched random set of neurons as a control. The normalized lesion effect is computed as random_accuracy - cluster_accuracy.
Combined lesion (input × hidden): Simultaneously lesions one input cluster and one hidden cluster for all (pre_i, post_j) combinations. Random combined lesion serves as control.
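The single-cluster lesions and their random controls can be sketched as below. This is a minimal illustration operating on one weight matrix; the actual masking in `leison.py` may differ, and `lesion`/`normalized_effect` are hypothetical names.

```python
import numpy as np

def lesion(W, idx, side="post"):
    # Zero all connections for the neurons in idx (either a cluster
    # or a size-matched random control set).
    # side="pre":  zero columns (input / pre-synaptic cluster)
    # side="post": zero rows (hidden / post-synaptic cluster)
    W = W.copy()
    if side == "pre":
        W[:, idx] = 0.0
    else:
        W[idx, :] = 0.0
    return W

def normalized_effect(random_acc, cluster_acc):
    # Per-task normalized lesion effect: random_accuracy - cluster_accuracy.
    # Positive values mean the cluster matters more than chance.
    return np.asarray(random_acc) - np.asarray(cluster_acc)
```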
Modulation lesion: For each modulation synapse cluster (derived from `col_labels_by_k[FIXED_K]`), two modes are tested:

- `zero_W` — zeros the static weight `W` at cluster synapses (removes connectivity)
- `freeze_M` — keeps `W` intact but freezes plasticity `M` at those synapses (removes learning)
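The distinction between the two modes can be sketched as one plasticity step under a synapse mask. This is an assumption-laden illustration (a simple decay-plus-outer-product rule, hypothetical function name); the exact dynamics live in `mpn.py`/`leison.py`.

```python
import numpy as np

def plastic_step_lesioned(W, M, pre, post, eta, lam, mask, mode):
    # mask: boolean matrix marking the lesioned cluster's synapses.
    M_new = lam * M + eta * np.outer(post, pre)  # unlesioned update
    if mode == "freeze_M":
        # Plasticity removed: masked synapses keep their old M value,
        # but the static weight W stays intact.
        M = np.where(mask, M, M_new)
    else:  # "zero_W"
        # Connectivity removed: static weight zeroed at masked synapses,
        # plasticity left running everywhere.
        W = np.where(mask, 0.0, W)
        M = M_new
    W_eff = W + W * M   # multiplicative form
    return W_eff, M
```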
Magnitude pruning: Zeros the lowest-magnitude fraction of mp_layer1.W at increasing sparsity levels (0–99.9%) to assess how much of the plastic weight matrix is functionally necessary.
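The pruning step can be sketched as follows (a minimal illustration; ties at the threshold may zero slightly more entries than the exact fraction, and `magnitude_prune` is a hypothetical name):

```python
import numpy as np

def magnitude_prune(W, sparsity):
    # Zero the lowest-|w| fraction of entries of W; sparsity in [0, 1).
    k = int(sparsity * W.size)
    if k == 0:
        return W.copy()
    # k-th smallest absolute value becomes the pruning threshold.
    thresh = np.partition(np.abs(W).ravel(), k - 1)[k - 1]
    pruned = W.copy()
    pruned[np.abs(pruned) <= thresh] = 0.0
    return pruned
```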
Results are saved to multiple_tasks_perf/{aname}/lesion_prune_results_{aname}.pkl.
`python leison_plot.py` — post-processes the lesion results to compute normalized effects and cross-analyses:
- Normalized lesion heatmaps: `random - cluster` effect for input/hidden and modulation clusters
- Combined heatmaps: side-by-side `zero_W` vs `freeze_M` with shared color scale
- Violin plots: distribution of normalized effect across tasks per cluster
- Cluster similarity vs lesion effect: Correlates cluster tuning similarity with functional lesion similarity (tests whether similar clusters have similar roles)
- Overmembership vs lesion difference: Relates modulation cluster enrichment in (input, hidden) pairs to the functional similarity between modulation lesion and combined lesion effects
Outputs are saved to multiple_tasks_norm/{aname}/.
`python state_space_shift.py` — analyzes how the network's hidden-state geometry shifts across tasks using PCA and subspace angles.
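The core comparison can be sketched as below: extract each task's top principal components from hidden activity and measure the principal angles between the subspaces. A minimal illustration, assuming a (time × hidden) activity matrix per task; `task_subspace_angles` is a hypothetical name.

```python
import numpy as np
from scipy.linalg import subspace_angles

def task_subspace_angles(H_a, H_b, n_pc=3):
    # H_a, H_b: (time, hidden) hidden-state activity for two tasks.
    # Compare their top-n_pc PCA subspaces via principal angles (radians).
    def top_pcs(H):
        Hc = H - H.mean(axis=0)                       # center
        _, _, Vt = np.linalg.svd(Hc, full_matrices=False)
        return Vt[:n_pc].T                            # (hidden, n_pc) basis
    return subspace_angles(top_pcs(H_a), top_pcs(H_b))
```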
| File | Purpose |
|---|---|
| `mpn.py` | Model definitions (`MultiPlasticLayer`, `DeepMultiPlasticNet`) |
| `mpn_tasks.py` | Task definitions and trial generators |
| `net_helpers.py` | Base network classes, weight initialization |
| `multiple_task.py` | Training loop |
| `multiple_task_analysis.py` | Post-training analysis and clustering pipeline |
| `clustering.py` | Hierarchical clustering with automatic k selection |
| `clustering_metric.py` | Cluster quality metrics |
| `leison.py` | Lesion and pruning experiments |
| `leison_plot.py` | Plotting utilities for lesion results |
| `state_space_shift.py` | State space / PCA analysis |
| `helper.py` | Shared utilities |
| `color_func.py` | Color palettes for plotting |
| Directory | Contents |
|---|---|
| `multiple_tasks/` | Checkpoints, training curves, cluster info |
| `multiple_tasks_perf/` | Lesion/pruning heatmaps and result pickles |
| `multiple_tasks_norm/` | Normalized lesion effect figures (from `leison_plot.py`) |
| `state_space/` | State space figures |
- Python 3.9+
- PyTorch (CUDA optional, detected automatically in `leison.py`)
- NumPy, SciPy, scikit-learn
- Matplotlib, seaborn
- h5py, hdf5plugin
- scienceplots (for analysis notebooks)
Model checkpoints and result files use a shared identifier string:
```
{task}_seed{seed}_{feature}+hidden{hidden}+batch{batch}{accfeature}
# e.g. everything_seed749_L21e4+hidden300+batch128+angle
```
All analysis scripts read `aname` from this pattern to locate the correct files.
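A parser for this identifier pattern can be sketched with a regex. The pattern itself comes from the template above; the function and group names are illustrative, and the actual scripts may split the string differently.

```python
import re

# Groups mirror the template:
# {task}_seed{seed}_{feature}+hidden{hidden}+batch{batch}{accfeature}
ANAME_RE = re.compile(
    r"(?P<task>.+?)_seed(?P<seed>\d+)_(?P<feature>[^+]+)"
    r"\+hidden(?P<hidden>\d+)\+batch(?P<batch>\d+)(?P<accfeature>.*)"
)

def parse_aname(aname):
    # Return the template fields as a dict, or None if aname
    # does not match the expected pattern.
    m = ANAME_RE.fullmatch(aname)
    return m.groupdict() if m else None
```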
Parts of this codebase were written with the assistance of Claude Code.