Source code for features.fooof
from collections.abc import Iterable
import numpy as np
from typing import TYPE_CHECKING
from py_neuromodulation.utils.types import (
NMBaseModel,
NMFeature,
BoolSelector,
FrequencyRange,
)
if TYPE_CHECKING:
from py_neuromodulation import NMSettings
class FooofAperiodicSettings(BoolSelector):
exponent: bool = True
offset: bool = True
knee: bool = True
class FooofPeriodicSettings(BoolSelector):
center_frequency: bool = False
band_width: bool = False
height_over_ap: bool = False
class FooofSettings(NMBaseModel):
aperiodic: FooofAperiodicSettings = FooofAperiodicSettings()
periodic: FooofPeriodicSettings = FooofPeriodicSettings()
windowlength_ms: float = 800
peak_width_limits: FrequencyRange = FrequencyRange(0.5, 12)
max_n_peaks: int = 3
min_peak_height: float = 0
peak_threshold: float = 2
freq_range_hz: FrequencyRange = FrequencyRange(2, 40)
knee: bool = True
[docs]
class FooofAnalyzer(NMFeature):
feat_name_map = {
"exponent": "exp",
"offset": "offset",
"knee": "knee_frequency",
"center_frequency": "cf",
"band_width": "bw",
"height_over_ap": "pw",
}
def __init__(
self, settings: "NMSettings", ch_names: Iterable[str], sfreq: float
) -> None:
self.settings = settings.fooof_settings
self.sfreq = sfreq
self.ch_names = ch_names
self.ap_mode = "knee" if self.settings.knee else "fixed"
self.num_samples = int(self.settings.windowlength_ms * sfreq / 1000)
self.f_vec = np.arange(0, int(self.num_samples / 2) + 1, 1)
assert (
self.settings.windowlength_ms <= settings.segment_length_features_ms
), f"fooof windowlength_ms ({settings.fooof.windowlength_ms}) needs to be smaller equal than segment_length_features_ms ({settings.segment_length_features_ms})."
assert (
self.settings.freq_range_hz[0] < sfreq
and self.settings.freq_range_hz[1] < sfreq
), f"fooof frequency range needs to be below sfreq, got {settings.fooof.freq_range_hz}"
from fooof import FOOOFGroup
self.fm = FOOOFGroup(
aperiodic_mode=self.ap_mode,
peak_width_limits=tuple(self.settings.peak_width_limits),
max_n_peaks=self.settings.max_n_peaks,
min_peak_height=self.settings.min_peak_height,
peak_threshold=self.settings.peak_threshold,
verbose=False,
)
[docs]
def calc_feature(self, data: np.ndarray) -> dict:
from scipy.fft import rfft
spectra = np.abs(rfft(data[:, -self.num_samples :])) # type: ignore
self.fm.fit(self.f_vec, spectra, self.settings.freq_range_hz)
if not self.fm.has_model or self.fm.null_inds_ is None:
raise RuntimeError("FOOOF failed to fit model to data.")
failed_fits: list[int] = self.fm.null_inds_
feature_results = {}
for ch_idx, ch_name in enumerate(self.ch_names):
FIT_PASSED = ch_idx not in failed_fits
exp = self.fm.get_params("aperiodic_params", "exponent")[ch_idx]
for feat in self.settings.aperiodic.get_enabled():
f_name = f"{ch_name}_fooof_a_{self.feat_name_map[feat]}"
if not FIT_PASSED:
feature_results[f_name] = None
elif feat == "knee" and exp == 0:
feature_results[f_name] = None
else:
params = self.fm.get_params("aperiodic_params", feat)[ch_idx]
if feat == "knee":
# If knee parameter is negative, set knee frequency to 0
if params < 0:
params = 0
else:
params = params ** (1 / exp)
feature_results[f_name] = np.nan_to_num(params)
peaks_dict: dict[str, np.ndarray | None] = {
"bw": self.fm.get_params("peak_params", "BW") if FIT_PASSED else None,
"cf": self.fm.get_params("peak_params", "CF") if FIT_PASSED else None,
"pw": self.fm.get_params("peak_params", "PW") if FIT_PASSED else None,
}
if type(peaks_dict["bw"]) is np.float64 or peaks_dict["bw"] is None:
peaks_dict["bw"] = [peaks_dict["bw"]]
peaks_dict["cf"] = [peaks_dict["cf"]]
peaks_dict["pw"] = [peaks_dict["pw"]]
for peak_idx in range(self.settings.max_n_peaks):
for feat in self.settings.periodic.get_enabled():
f_name = f"{ch_name}_fooof_p_{peak_idx}_{self.feat_name_map[feat]}"
feature_results[f_name] = (
peaks_dict[self.feat_name_map[feat]][peak_idx]
if peak_idx < len(peaks_dict[self.feat_name_map[feat]])
else None
)
return feature_results