Foundation Models
coco_pipe.decoding integrates pretrained neural-network backbones for
EEG/MEG decoding (implemented in coco_pipe.decoding.foundation_models)
directly into the Experiment workflow, configured through
coco_pipe.decoding.configs just like classical estimators. Foundation
models can be used in three modes:
Frozen embedding extraction: extract features from a fixed backbone, then decode with a classical scikit-learn estimator.
Frozen backbone + trainable head: keep the backbone fixed and train only the output head.
Fine-tuning: adapt the backbone itself, either by updating all of its parameters (full) or by training low-rank adapters (LoRA, QLoRA).
All modes enter through Experiment.run(...) and are compatible with the
outer CV loop, meaning the foundation model is fit/embedded inside the
training partition of each fold.
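To make the fold-wise guarantee concrete, here is a minimal, library-independent sketch of a grouped outer loop. It is not coco-pipe's implementation: the helpers `group_folds` and `fit_center` are hypothetical, the split is grouped but not stratified, and the "model" is just a mean learned from the training partition.

```python
def group_folds(groups, n_splits):
    """Yield (train_idx, test_idx) so that no group spans both partitions."""
    unique = sorted(set(groups))
    fold_of = {g: i % n_splits for i, g in enumerate(unique)}  # round-robin assignment
    for k in range(n_splits):
        test = [i for i, g in enumerate(groups) if fold_of[g] == k]
        train = [i for i, g in enumerate(groups) if fold_of[g] != k]
        yield train, test


def fit_center(values):
    """Stand-in for any fitted step (embedding, scaler, head): learns a mean."""
    return sum(values) / len(values)


X = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
subjects = ["s1", "s1", "s2", "s2", "s3", "s3"]

fold_means = []
for train_idx, test_idx in group_folds(subjects, n_splits=3):
    # Fit on the training partition only; held-out samples are never seen here.
    center = fit_center([X[i] for i in train_idx])
    test_residuals = [X[i] - center for i in test_idx]  # apply the fitted step to test data
    fold_means.append(center)
    # No subject appears on both sides of the split.
    assert not {subjects[i] for i in train_idx} & {subjects[i] for i in test_idx}
```

Each fold learns a different `center` because it sees different training subjects, which is the property that keeps held-out data out of the fitted step.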
—
1. Embedding Extraction (Frozen Backbone)
The simplest foundation model workflow: freeze the backbone and use it as a fixed feature extractor. A classical scikit-learn model decodes from the extracted embeddings.
```python
from coco_pipe.decoding import Experiment  # import path assumed; adjust to your install
from coco_pipe.decoding.configs import (
    ExperimentConfig,
    CVConfig,
    FoundationEmbeddingModelConfig,
    FrozenBackboneDecoderConfig,
    ClassicalModelConfig,
)

config = ExperimentConfig(
    task="classification",
    models={
        "labram_probe": FrozenBackboneDecoderConfig(
            backbone=FoundationEmbeddingModelConfig(
                backend="braindecode",
                model_key="labram",
                pooling="mean",
                cache_embeddings=True,  # cache to disk for reuse across folds
            ),
            head=ClassicalModelConfig(
                estimator="LogisticRegression",
                params={"max_iter": 1000},
            ),
        )
    },
    metrics=["balanced_accuracy"],
    cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"),
)

# X_epochs: epoched data; y: labels; meta: per-sample metadata carrying the "Subject" group key
result = Experiment(config).run(X_epochs, y, sample_metadata=meta)
```
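The `pooling="mean"` setting presumably collapses the backbone's per-token outputs into one vector per window. A minimal sketch of that idea in plain Python follows. It assumes the backbone returns a list of token embeddings for each window; the actual tensor layout inside coco-pipe is not described here, and `mean_pool` is a hypothetical helper.

```python
def mean_pool(token_embeddings):
    """Average a (n_tokens, emb_dim) nested list over the token axis."""
    n_tokens = len(token_embeddings)
    emb_dim = len(token_embeddings[0])
    return [
        sum(tok[d] for tok in token_embeddings) / n_tokens
        for d in range(emb_dim)
    ]


# One window represented by three tokens of dimension 2.
window_tokens = [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]
pooled = mean_pool(window_tokens)  # one fixed-length vector per window
```

Fixed-length vectors of this kind are what a classical head such as LogisticRegression can consume directly.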
1.1 Embedding Cache
cache_embeddings=True writes extracted embeddings to a content-addressed
disk cache. Because a frozen backbone is a deterministic, target-independent
function of the window content, the cache key is derived from the input windows
and the backbone fingerprint — so embeddings are computed once and reused across
folds and models without leakage. No manual cache management is required.
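The idea of a content-addressed key can be sketched with the standard library. This is an illustration under stated assumptions, not the hashing scheme coco-pipe uses: the function name `embedding_cache_key`, the fingerprint strings, and the choice of SHA-256 are all hypothetical.

```python
import hashlib
import struct


def embedding_cache_key(window, backbone_fingerprint):
    """Hash the window samples together with a backbone identifier."""
    h = hashlib.sha256()
    h.update(backbone_fingerprint.encode("utf-8"))
    h.update(b"\x00")  # separator between fingerprint and data
    # Include the shape so differently shaped windows cannot collide.
    h.update(struct.pack("<2q", len(window), len(window[0])))
    for channel in window:
        h.update(struct.pack(f"<{len(channel)}d", *channel))  # little-endian float64
    return h.hexdigest()


window = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]  # (n_channels, n_times)
key_a = embedding_cache_key(window, "labram:weights-v1")
key_b = embedding_cache_key([row[:] for row in window], "labram:weights-v1")
key_c = embedding_cache_key(window, "labram:weights-v2")
```

Identical window content yields the same key regardless of which fold it appears in (`key_a == key_b`), while a different backbone yields a different key (`key_c`). Labels never enter the hash, which is why reuse across folds does not leak target information.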
1.2 Supported Models & Backends
coco-pipe integrates the following foundation models natively, registered in
coco_pipe.decoding._specs.ESTIMATOR_SPECS (browse them at runtime with
list_foundation_models()):
| Model | Hub repo | Emb. dim | Sampling rate (sfreq) | Channels | Interpolation | Train modes | Backend |
|---|---|---|---|---|---|---|---|
| BENDR | | 512 | 256.0 Hz | varies | no | | |
| BIOT | | 256 | 200.0 Hz | varies | no | | |
| CBraMod | | 200 | 200.0 Hz | varies | no | | |
| EEGPT | | 2048 | 250.0 Hz | varies | no | | |
| LaBraM | | 200 | 200.0 Hz | 128 | yes | | |
| LUNA | | 256 | 200.0 Hz | varies | no | | |
| REVE | | 512 | 200.0 Hz | varies | no | | |
| SignalJEPA | | 64 | 200.0 Hz | varies | no | | |
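The embedding dimension and sampling rate determine what you should expect going into and coming out of each backbone. The sketch below copies those two columns from the table into a plain dictionary and derives the resampled window length and the shape of the pooled feature matrix. The dictionary is illustrative and does not reflect the structure of `ESTIMATOR_SPECS`. Whether coco-pipe resamples for you is not stated here, so treat `resampled_length` as a planning aid only.

```python
# Values copied from the table above; helper names are hypothetical.
MODEL_SPECS = {
    "BENDR": {"emb_dim": 512, "sfreq": 256.0},
    "BIOT": {"emb_dim": 256, "sfreq": 200.0},
    "CBraMod": {"emb_dim": 200, "sfreq": 200.0},
    "EEGPT": {"emb_dim": 2048, "sfreq": 250.0},
    "LaBraM": {"emb_dim": 200, "sfreq": 200.0},
    "LUNA": {"emb_dim": 256, "sfreq": 200.0},
    "REVE": {"emb_dim": 512, "sfreq": 200.0},
    "SignalJEPA": {"emb_dim": 64, "sfreq": 200.0},
}


def resampled_length(n_times, sfreq_in, model):
    """Number of samples after resampling a window to the model's rate."""
    return round(n_times * MODEL_SPECS[model]["sfreq"] / sfreq_in)


def feature_matrix_shape(n_windows, model):
    """Shape of the pooled embedding matrix handed to a classical head."""
    return (n_windows, MODEL_SPECS[model]["emb_dim"])


n = resampled_length(n_times=1000, sfreq_in=500.0, model="LaBraM")
shape = feature_matrix_shape(n_windows=120, model="EEGPT")
```

For a 2-second window recorded at 500 Hz, LaBraM's 200 Hz rate gives 400 samples, and 120 windows embedded with EEGPT give a 120 × 2048 feature matrix.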
—
2. Neural Fine-Tuning (LoRA / QLoRA)
```python
from coco_pipe.decoding.configs import (
    ExperimentConfig,
    CVConfig,
    NeuralFineTuneConfig,
    LoRAConfig,
    QuantizationConfig,
    DeviceConfig,
    CheckpointConfig,
)

config = ExperimentConfig(
    task="classification",
    models={
        "reve_qlora": NeuralFineTuneConfig(
            backend="hugging_face",
            model_key="reve",
            input_kind="epoched",
            train_mode="qlora",
            lora=LoRAConfig(r=16, alpha=32, dropout=0.05),
            quantization=QuantizationConfig(enabled=True, load_in_4bit=True),  # 4-bit backbone weights
            device=DeviceConfig(device="auto", precision="bf16"),
            checkpoints=CheckpointConfig(save="best"),
        )
    },
    metrics=["balanced_accuracy"],
    cv=CVConfig(strategy="stratified_group_kfold", n_splits=5, group_key="Subject"),
)
```
2.1 Training Modes
| Mode | Description |
|---|---|
| Full fine-tuning | Updates all backbone parameters. Highest capacity; requires the most memory. |
| LoRA | Low-Rank Adaptation. Trains small low-rank matrices injected into the transformer attention layers while the backbone stays frozen. Memory-efficient. |
| QLoRA | Quantized LoRA. The backbone is frozen and quantized to 4-bit; LoRA adapters are trained in higher precision. Most memory-efficient option. |
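The memory ordering of the three modes follows from simple arithmetic. The sketch below compares trainable parameter counts and weight storage for a single projection matrix. The 512 × 512 layer size is a made-up example rather than a property of any model listed here, and the byte counts ignore quantization constants, activations, and optimizer state.

```python
def lora_trainable_params(d_out, d_in, r):
    """Parameters in the two low-rank factors B (d_out x r) and A (r x d_in)."""
    return r * (d_out + d_in)


def weight_bytes(n_params, bits):
    """Storage for the weights alone at the given bit width."""
    return n_params * bits // 8


d_out = d_in = 512      # hypothetical attention projection size
full = d_out * d_in     # parameters updated by full fine-tuning
lora = lora_trainable_params(d_out, d_in, r=16)

fp16_bytes = weight_bytes(full, bits=16)  # frozen backbone weights under LoRA
int4_bytes = weight_bytes(full, bits=4)   # frozen backbone weights under QLoRA
```

With these numbers, LoRA trains 16,384 parameters instead of 262,144, and 4-bit storage cuts the frozen weights to a quarter of their 16-bit size.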
2.2 LoRA Configuration
| Parameter | Description |
|---|---|
| r | Rank of the LoRA decomposition. Higher rank → more trainable parameters. Default 16. |
| alpha | Scaling factor for the adapter update (standard LoRA scales it by alpha / r). |
| dropout | Dropout on LoRA layers. Default 0.0. |
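How rank and scaling interact is easiest to see in the forward pass. The following is a minimal sketch of the standard LoRA formulation, y = W x + (alpha / r) · B (A x), written with plain lists. It assumes coco-pipe follows that convention, which is not confirmed here. Dropout is omitted, and all matrix values are made up.

```python
def matvec(M, x):
    """Multiply a nested-list matrix by a vector."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, alpha, r):
    """y = W x + (alpha / r) * B (A x), with W frozen and A, B trainable."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


W = [[1.0, 0.0], [0.0, 1.0]]  # frozen 2x2 weight
A = [[1.0, 1.0]]              # r x d_in, with r = 1
B_zero = [[0.0], [0.0]]       # d_out x r, zero-initialised as in standard LoRA
B_trained = [[0.5], [0.0]]    # hypothetical values after training
x = [2.0, 3.0]

y_init = lora_forward(W, A, B_zero, x, alpha=2, r=1)
y_trained = lora_forward(W, A, B_trained, x, alpha=2, r=1)
```

Because B starts at zero, the adapted layer initially reproduces the frozen backbone exactly (`y_init` equals `W x`). Under this convention, r=16 with alpha=32 corresponds to a scale of 2.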
—
3. Diagnostic Artifacts
Trainable neural models expose training diagnostics via NeuralTrainable
protocol methods:
```python
artifacts = result.get_model_artifacts()
# columns: Model, Fold, ArtifactKey, ArtifactValue

# Per-fold training history
history = result.get_model_artifacts(artifact_type="training_history")

# Checkpoint manifest
checkpoints = result.get_model_artifacts(artifact_type="checkpoints")
```
The NeuralTrainable protocol requires:
get_training_history() → list[dict]: loss/metric per epoch.
get_checkpoint_manifest() → dict: saved checkpoint paths and best epoch.
get_model_card_info() → dict: architecture and training summary.
get_failure_diagnostics() → dict: NaN detection, gradient norms.
get_artifact_metadata() → dict: aggregated artifact dictionary.
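The five methods above can be written down as a structural interface with `typing.Protocol`. This is a hypothetical mirror for illustration, not the definition shipped in coco-pipe. The class names, dictionary keys, and values below are all invented.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class NeuralTrainableSketch(Protocol):
    """Hypothetical mirror of the five methods listed above."""

    def get_training_history(self) -> list[dict]: ...
    def get_checkpoint_manifest(self) -> dict: ...
    def get_model_card_info(self) -> dict: ...
    def get_failure_diagnostics(self) -> dict: ...
    def get_artifact_metadata(self) -> dict: ...


class DummyModel:
    """Toy implementation with made-up values, for illustration only."""

    def get_training_history(self):
        return [{"epoch": 0, "loss": 0.9}, {"epoch": 1, "loss": 0.6}]

    def get_checkpoint_manifest(self):
        return {"best_epoch": 1, "paths": []}

    def get_model_card_info(self):
        return {"architecture": "dummy"}

    def get_failure_diagnostics(self):
        return {"nan_detected": False}

    def get_artifact_metadata(self):
        return {
            "training_history": self.get_training_history(),
            "checkpoints": self.get_checkpoint_manifest(),
        }


model = DummyModel()
conforms = isinstance(model, NeuralTrainableSketch)  # structural check on method names
best = min(model.get_training_history(), key=lambda rec: rec["loss"])
```

The `isinstance` check only verifies that the methods exist, not their return types, so it works as a quick sanity test for a custom estimator rather than a full contract.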
—
4. Required Dependencies
Foundation models require optional extras:
braindecode provider: pip install coco-pipe[braindecode]
huggingface / qlora provider: pip install coco-pipe[hf,peft,quant]
reve provider: Contact the REVE team for access.
```shell
# Quote the requirement so shells such as zsh do not expand the brackets
pip install "coco-pipe[hf,peft,quant]"  # QLoRA path
pip install "coco-pipe[braindecode]"    # Braindecode backbone
```
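Before launching a long run, it can help to confirm that the optional packages are importable. The sketch below uses only the standard library. The mapping from backend to module names is an assumption about what the extras install, so check it against the extras coco-pipe actually declares.

```python
from importlib.util import find_spec


def missing_modules(module_names):
    """Return the subset of top-level modules that cannot be found."""
    return [name for name in module_names if find_spec(name) is None]


# Assumed mapping from backend to importable module names.
BACKEND_MODULES = {
    "braindecode": ["braindecode"],
    "hugging_face": ["transformers", "peft", "bitsandbytes"],
}

report = {backend: missing_modules(mods) for backend, mods in BACKEND_MODULES.items()}
for backend, missing in report.items():
    status = "ok" if not missing else f"missing: {', '.join(missing)}"
    print(f"{backend}: {status}")
```

`find_spec` locates a module without importing it, so the check stays fast even when heavy packages are installed.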