# Source code for heavyedge_classify.model

"""MiniRocket-based probabilistic classifier of 1D signals."""

import numbers
import warnings

from aeon.transformations.collection.convolution_based import MiniRocket
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import RidgeClassifierCV
from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import Pipeline

from .calibration import (
    IsotonicOvOCalibratedClassifierCV,
    SigmoidOvOCalibratedClassifierCV,
    TemperatureCalibratedClassifierCV,
)

# Public API of this module: only the classifier factory is exported.
__all__ = ["minirocket_classifier"]


def minirocket_classifier(
    cv=5,
    calibration="sigmoid",
    n_jobs=None,
    verbose=False,
    random_state=0,
    n_splits=None,
):
    """Build a MiniRocket-based probabilistic classifier of 1D signals.

    The returned estimator is unfitted. It wraps a :class:`Pipeline` of
    ``MiniRocket`` features followed by ``RidgeClassifierCV`` with
    ``class_weight="balanced"``, and calibrates that pipeline's outputs into
    probabilities using cross-validation.

    Parameters
    ----------
    cv : int, iterable, or cross-validation generator, default=5
        Cross-validation strategy used by the calibrator. If an integer
        (including NumPy integer scalars) is passed, it is the number of
        folds for a shuffled stratified k-fold CV seeded with
        *random_state*. Any other value is forwarded to the calibrator as-is.
    calibration : {"sigmoid", "isotonic", "temperature", "sigmoid_ovo", "isotonic_ovo"}
        Calibration method for the classifier; defaults to ``"sigmoid"``.
        ``"sigmoid"`` and ``"isotonic"`` use scikit-learn's
        ``CalibratedClassifierCV`` with that ``method``; ``"temperature"``,
        ``"sigmoid_ovo"`` and ``"isotonic_ovo"`` use
        ``TemperatureCalibratedClassifierCV``,
        ``SigmoidOvOCalibratedClassifierCV`` and
        ``IsotonicOvOCalibratedClassifierCV`` from this package respectively.
    n_jobs : int, default=None
        Number of jobs to run in parallel; forwarded to the calibrator.
    verbose : bool, default=False
        Forwarded to the pipeline's ``verbose`` flag (prints pipeline steps).
    random_state : int, default=0
        Random seed for reproducibility. Seeds ``MiniRocket`` and, when *cv*
        is an integer, the shuffled ``StratifiedKFold``.
    n_splits : int, optional
        Number of splits for cross-validation. If passed, overrides *cv*.

        .. deprecated:: 1.4.0
            The *n_splits* parameter is deprecated and will be removed in a
            future version. Use *cv* instead.

    Returns
    -------
    model
        Unfitted MiniRocket-based probabilistic classifier.

    Raises
    ------
    ValueError
        If *calibration* is not one of the supported methods.

    Examples
    --------
    >>> from heavyedge import ProfileData
    >>> from heavyedge_classify.samples import get_sample_path
    >>> from heavyedge_classify.model import minirocket_classifier
    >>> import numpy as np
    >>> model = minirocket_classifier(cv=5, random_state=42)
    >>> X, _, _ = ProfileData(get_sample_path("Profiles.h5"))[:]
    >>> y = np.load(get_sample_path("labels.npy"))
    >>> model.fit(X[:5], y[:5])
    CalibratedClassifierCV(...)
    """
    if n_splits is not None:
        warnings.warn(
            (
                "n_splits is deprecated and will be removed in a future version. "
                "Use cv instead."
            ),
            DeprecationWarning,
            stacklevel=2,
        )
        cv = n_splits

    # Validate before constructing any estimator so a typo fails fast.
    supported = (
        "sigmoid",
        "isotonic",
        "temperature",
        "sigmoid_ovo",
        "isotonic_ovo",
    )
    if calibration not in supported:
        raise ValueError(f"Unsupported calibration method: {calibration}")

    pipeline = Pipeline(
        [
            ("minirocket", MiniRocket(random_state=random_state)),
            ("classifier", RidgeClassifierCV(class_weight="balanced")),
        ],
        verbose=verbose,
    )

    # numbers.Integral also matches NumPy integer scalars (e.g. np.int64),
    # which plain ``int`` does not; without this they would be forwarded
    # unchanged and miss the shuffled, seeded splitter below.
    if isinstance(cv, numbers.Integral):
        cv = StratifiedKFold(
            n_splits=int(cv), shuffle=True, random_state=random_state
        )

    if calibration in ("sigmoid", "isotonic"):
        return CalibratedClassifierCV(
            estimator=pipeline,
            method=calibration,
            cv=cv,
            n_jobs=n_jobs,
        )

    # The remaining (already validated) methods share one constructor shape.
    calibrator_cls = {
        "temperature": TemperatureCalibratedClassifierCV,
        "sigmoid_ovo": SigmoidOvOCalibratedClassifierCV,
        "isotonic_ovo": IsotonicOvOCalibratedClassifierCV,
    }[calibration]
    return calibrator_cls(estimator=pipeline, cv=cv, n_jobs=n_jobs)