coco_pipe.decoding.foundation_models.estimators
Clone-safe sklearn estimators backed by lazily loaded foundation models.
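"Clone-safe" and "lazily loaded" most likely refer to the usual scikit-learn estimator contract. Under that contract, `__init__` only stores its arguments, so `sklearn.base.clone` can rebuild an unfitted copy from `get_params()`. Expensive objects such as model weights are created in `fit` and stored in trailing-underscore attributes, in the same way as `backend_` and `prepared_` below. The following is a minimal sketch under that assumption. `load_backbone`, `LazyBackboneTransformer` and `LOAD_CALLS` are hypothetical stand-ins and are not part of coco_pipe.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone

# Records every (hypothetical) weight load so laziness can be observed.
LOAD_CALLS = []


def load_backbone(model_key):
    """Hypothetical loader; a fixed random projection stands in for real weights."""
    LOAD_CALLS.append(model_key)
    rng = np.random.default_rng(0)
    return rng.standard_normal((8, 4))  # maps 8 input channels to 4 features


class LazyBackboneTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, model_key, device="auto"):
        # Clone-safe: store constructor arguments unchanged, do no work here.
        self.model_key = model_key
        self.device = device

    def fit(self, X, y=None):
        # Lazy: the expensive object is created on first fit, never in __init__.
        self.backbone_ = load_backbone(self.model_key)
        return self

    def transform(self, X):
        # X has shape (n_samples, 8); project it with the loaded weights.
        return np.asarray(X) @ self.backbone_


est = LazyBackboneTransformer("demo-model")
cloned = clone(est)  # copies hyperparameters only; nothing is loaded
print(len(LOAD_CALLS))  # 0
cloned.fit(np.zeros((3, 8)))
print(len(LOAD_CALLS))  # 1
print(hasattr(est, "backbone_"))  # False: the original stays unfitted
```

Deferring the load to `fit` keeps cloning cheap. This matters when cross-validation or grid search clones the estimator once per split.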
Classes
| FrozenBackboneTransformer | Target-independent frozen feature extractor suitable for sklearn pipelines. |
| FoundationClassifier | Lazy trainable foundation-model classifier with grouped validation. |
Functions
| clear_frozen_embedding_cache() | Empty the shared frozen-backbone embedding cache. |
Module Contents
- coco_pipe.decoding.foundation_models.estimators.clear_frozen_embedding_cache()
Empty the shared frozen-backbone embedding cache.
- Return type:
None
- class coco_pipe.decoding.foundation_models.estimators.FrozenBackboneTransformer(model_key, backend='auto', device='auto', pooling='mean', sfreq=None, ch_names=None, cache_embeddings=False, backend_kwargs=None)
Bases: sklearn.base.BaseEstimator, sklearn.base.TransformerMixin
Target-independent frozen feature extractor suitable for sklearn pipelines.
- Parameters:
model_key, backend, device, pooling, sfreq, ch_names, cache_embeddings, backend_kwargs
- backend_ = None
- prepared_ = None
- backend = 'auto'
- device = 'auto'
- pooling = 'mean'
- sfreq = None
- ch_names = None
- cache_embeddings = False
- backend_kwargs = None
- fit(X, y=None)
- Parameters:
X (numpy.ndarray)
y (numpy.ndarray | None)
- transform(X)
- Parameters:
X (numpy.ndarray)
- Return type:
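A target-independent transformer ignores `y` in `fit`, so it can sit in front of any sklearn classifier in a `Pipeline`. The sketch below assumes several things that this page does not confirm. Inputs are taken to be epoched signals of shape (n_epochs, n_channels, n_times), which `sfreq` and `ch_names` suggest. `pooling='mean'` is taken to average per-token embeddings over the time axis. `ToyFrozenTransformer` is hypothetical, and its fixed random projection only stands in for a pretrained backbone.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline


class ToyFrozenTransformer(BaseEstimator, TransformerMixin):
    """Hypothetical stand-in: a fixed random 'backbone' plus token pooling."""

    def __init__(self, embed_dim=16, pooling="mean", random_state=0):
        self.embed_dim = embed_dim
        self.pooling = pooling
        self.random_state = random_state

    def fit(self, X, y=None):
        # Target-independent: y is accepted for pipeline compatibility but ignored.
        n_channels = np.asarray(X).shape[1]
        rng = np.random.default_rng(self.random_state)
        self.weights_ = rng.standard_normal((n_channels, self.embed_dim))
        return self

    def transform(self, X):
        X = np.asarray(X)  # (n_epochs, n_channels, n_times)
        # Treat each time step as a token: (n_epochs, n_times, embed_dim).
        tokens = np.tanh(np.einsum("ect,cd->etd", X, self.weights_))
        if self.pooling == "mean":
            return tokens.mean(axis=1)
        if self.pooling == "max":
            return tokens.max(axis=1)
        raise ValueError(f"unknown pooling: {self.pooling!r}")


rng = np.random.default_rng(1)
X = rng.standard_normal((40, 8, 50))  # 40 epochs, 8 channels, 50 samples
y = np.repeat([0, 1], 20)
X[y == 1, 0] += 1.0  # inject a simple class difference on channel 0

pipe = make_pipeline(ToyFrozenTransformer(pooling="mean"), LogisticRegression())
pipe.fit(X, y)
print(pipe[0].transform(X).shape)  # (40, 16)
print(pipe.score(X, y))
```

The backbone stays frozen, so only the downstream classifier learns from the labels. This is what makes such a transformer safe to refit inside each cross-validation split.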
- class coco_pipe.decoding.foundation_models.estimators.FoundationClassifier(model_key, backend='auto', train_mode='full', device='auto', n_outputs=None, sfreq=None, ch_names=None, trainer=None, lora=None, backend_kwargs=None, checkpoints=None, class_weight='balanced', random_state=42)
Bases: sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin
Lazy trainable foundation-model classifier with grouped validation.
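- Parameters:
model_key, backend, train_mode, device, n_outputs, sfreq, ch_names, trainer, lora, backend_kwargs, checkpoints, class_weight, random_state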
- backend_ = None
- prepared_ = None
- checkpoint_path_ = None
- backend = 'auto'
- train_mode = 'full'
- device = 'auto'
- n_outputs = None
- sfreq = None
- ch_names = None
- trainer = None
- lora = None
- backend_kwargs = None
- checkpoints = None
- class_weight = 'balanced'
- random_state = 42
- fit(X, y, groups=None)
- Parameters:
X (numpy.ndarray)
y (numpy.ndarray)
groups (numpy.ndarray | None)
- restore_checkpoint(path=None)
Restore every saved backend component into this fitted estimator.
- Parameters:
path (str | pathlib.Path | None)
- predict(X)
- Parameters:
X (numpy.ndarray)
- Return type:
- predict_proba(X)
- Parameters:
X (numpy.ndarray)
- Return type:
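`fit` accepts `groups`, and the class is described as having "grouped validation". A reasonable interpretation is that whole groups, for example subjects, are held out as the validation set, so that no group leaks between training and validation. The sketch below makes further assumptions. It assumes that `class_weight='balanced'` follows scikit-learn's formula, computed on the training labels. It also assumes that `predict` returns the class with the highest `predict_proba` score. `grouped_train_val_split` is a hypothetical helper. `GroupShuffleSplit` and `compute_class_weight` are real scikit-learn APIs. `random_state=42` simply mirrors the classifier's default.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit
from sklearn.utils.class_weight import compute_class_weight


def grouped_train_val_split(y, groups, val_size=0.2, random_state=42):
    """Hypothetical helper: hold out whole groups (e.g. subjects) for validation."""
    splitter = GroupShuffleSplit(
        n_splits=1, test_size=val_size, random_state=random_state
    )
    train_idx, val_idx = next(splitter.split(np.zeros((len(y), 1)), y, groups))
    return train_idx, val_idx


y = np.array([0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1])
groups = np.repeat(np.arange(6), 2)  # 6 subjects with 2 epochs each

train_idx, val_idx = grouped_train_val_split(y, groups)
# No subject contributes epochs to both sides of the split.
assert set(groups[train_idx]).isdisjoint(groups[val_idx])

# class_weight='balanced' in scikit-learn terms:
# weight_c = n_samples / (n_classes * n_samples_in_class_c), on training labels only.
classes = np.unique(y[train_idx])
weights = compute_class_weight("balanced", classes=classes, y=y[train_idx])
print(dict(zip(classes.tolist(), weights.round(3).tolist())))

# predict() as the argmax of predict_proba() mapped back to the class labels.
proba = np.array([[0.7, 0.3], [0.2, 0.8]])
print(classes[proba.argmax(axis=1)])  # [0 1]
```

Splitting by group rather than by epoch keeps the validation score from being inflated by near-duplicate epochs that come from the same subject.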