coco_pipe.decoding.foundation_models.estimators#

Clone-safe sklearn estimators backed by lazily loaded foundation models.

Classes#

FrozenBackboneTransformer

Target-independent frozen feature extractor suitable for sklearn pipelines.

FoundationClassifier

Lazy trainable foundation-model classifier with grouped validation.

Functions#

clear_frozen_embedding_cache()

Empty the shared frozen-backbone embedding cache.

Module Contents#

coco_pipe.decoding.foundation_models.estimators.clear_frozen_embedding_cache()#

Empty the shared frozen-backbone embedding cache.

Return type:

None

class coco_pipe.decoding.foundation_models.estimators.FrozenBackboneTransformer(model_key, backend='auto', device='auto', pooling='mean', sfreq=None, ch_names=None, cache_embeddings=False, backend_kwargs=None)#

Bases: sklearn.base.BaseEstimator, sklearn.base.TransformerMixin

Target-independent frozen feature extractor suitable for sklearn pipelines.

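Parameters:

model_key

backend (default 'auto')

device (default 'auto')

pooling (default 'mean')

sfreq (default None)

ch_names (default None)

cache_embeddings (default False)

backend_kwargs (default None)

Attributes: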
backend_ = None#
prepared_ = None#
backend = 'auto'#
device = 'auto'#
pooling = 'mean'#
sfreq = None#
ch_names = None#
cache_embeddings = False#
backend_kwargs = None#
fit(X, y=None)#
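Parameters:

X

y (default None; accepted for sklearn pipeline compatibility — the transformer is documented as target-independent)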
transform(X)#
Parameters:

X (numpy.ndarray)

Return type:

numpy.ndarray

class coco_pipe.decoding.foundation_models.estimators.FoundationClassifier(model_key, backend='auto', train_mode='full', device='auto', n_outputs=None, sfreq=None, ch_names=None, trainer=None, lora=None, backend_kwargs=None, checkpoints=None, class_weight='balanced', random_state=42)#

Bases: sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin

Lazy trainable foundation-model classifier with grouped validation.

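Parameters:

model_key

backend (default 'auto')

train_mode (default 'full')

device (default 'auto')

n_outputs (default None)

sfreq (default None)

ch_names (default None)

trainer (default None)

lora (default None)

backend_kwargs (default None)

checkpoints (default None)

class_weight (default 'balanced')

random_state (default 42)

Attributes: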
backend_ = None#
prepared_ = None#
checkpoint_path_ = None#
backend = 'auto'#
train_mode = 'full'#
device = 'auto'#
n_outputs = None#
sfreq = None#
ch_names = None#
trainer = None#
lora = None#
backend_kwargs = None#
checkpoints = None#
class_weight = 'balanced'#
random_state = 42#
fit(X, y, groups=None)#
Parameters:

X

y

groups (default None; group labels used for grouped validation)
restore_checkpoint(path=None)#

Restore every saved backend component into this fitted estimator.

Parameters:

path (str | pathlib.Path | None)

predict(X)#
Parameters:

X (numpy.ndarray)

Return type:

numpy.ndarray

predict_proba(X)#
Parameters:

X (numpy.ndarray)

Return type:

numpy.ndarray

get_training_history()#
Return type:

list[dict[str, Any]]

get_checkpoint_manifest()#
Return type:

dict[str, Any]

get_model_card_info()#
Return type:

dict[str, Any]

get_failure_diagnostics()#
Return type:

dict[str, Any]

get_artifact_metadata()#
Return type:

dict[str, Any]