Source code for features.bispectra

from collections.abc import Iterable
from pydantic import field_validator
from typing import TYPE_CHECKING, Callable

import numpy as np

from py_neuromodulation.utils.types import (
    NMBaseModel,
    NMFeature,
    BoolSelector,
    FrequencyRange,
)

if TYPE_CHECKING:
    from py_neuromodulation import NMSettings


class BispectraComponents(BoolSelector):
    absolute: bool = True
    real: bool = True
    imag: bool = True
    phase: bool = True


class BispectraFeatures(BoolSelector):
    mean: bool = True
    sum: bool = True
    var: bool = True


class BispectraSettings(NMBaseModel):
    f1s: FrequencyRange = FrequencyRange(5, 35)
    f2s: FrequencyRange = FrequencyRange(5, 35)
    compute_features_for_whole_fband_range: bool = True
    frequency_bands: list[str] = ["theta", "alpha", "low_beta", "high_beta"]

    components: BispectraComponents = BispectraComponents()
    bispectrum_features: BispectraFeatures = BispectraFeatures()

    @field_validator("f1s", "f2s")
    def test_range(cls, filter_range):
        assert (
            filter_range[1] > filter_range[0]
        ), f"second frequency range value needs to be higher than first one, got {filter_range}"
        return filter_range

    @field_validator("frequency_bands")
    def fbands_spaces_to_underscores(cls, frequency_bands):
        return [f.replace(" ", "_") for f in frequency_bands]


FEATURE_DICT: dict[str, Callable] = {
    "mean": np.nanmean,
    "sum": np.nansum,
    "var": np.nanvar,
}

COMPONENT_DICT: dict[str, Callable] = {
    "real": lambda obj: getattr(obj, "real"),
    "imag": lambda obj: getattr(obj, "imag"),
    "absolute": np.abs,
    "phase": np.angle,
}


[docs] class Bispectra(NMFeature): def __init__( self, settings: "NMSettings", ch_names: Iterable[str], sfreq: float ) -> None: self.sfreq = sfreq self.ch_names = ch_names self.frequency_ranges_hz = settings.frequency_ranges_hz self.settings: BispectraSettings = settings.bispectrum_settings assert all( f_band_bispectrum in settings.frequency_ranges_hz for f_band_bispectrum in self.settings.frequency_bands ), ( "bispectrum selected frequency bands don't match the ones" "specified in s['frequency_ranges_hz']" f"bispectrum frequency bands: {self.settings.frequency_bands}" f"specified frequency_ranges_hz: {settings.frequency_ranges_hz}" ) self.used_features = self.settings.bispectrum_features.get_enabled() self.min_freq = min( self.settings.f1s.frequency_low_hz, self.settings.f2s.frequency_low_hz ) self.max_freq = max( self.settings.f1s.frequency_high_hz, self.settings.f2s.frequency_high_hz ) # self.freqs: np.ndarray = np.array([]) # In case we pre-computed this
[docs] def calc_feature(self, data: np.ndarray) -> dict: from pybispectra import compute_fft, WaveShape fft_coeffs, freqs = compute_fft( data=np.expand_dims(data, axis=0), sampling_freq=self.sfreq, n_points=data.shape[1], verbose=False, ) # freqs is batch independent, except for the last batch perhaps (if it has different shape) # but it's computed by compute_fft regardless so no advantage in pre-computing it # if not self.freqs = self.freqs = np.fft.rfftfreq(n=data.shape[1], d = 1 / sfreq) # fft_coeffs shape: [epochs, channels, frequencies] f_spectrum_range = freqs[ np.logical_and(freqs >= self.min_freq, freqs <= self.max_freq) ] waveshape = WaveShape( data=fft_coeffs, freqs=freqs, sampling_freq=self.sfreq, verbose=False, ) waveshape.compute( f1s=tuple(self.settings.f1s), # type: ignore f2s=tuple(self.settings.f2s), # type: ignore ) waveshape = waveshape.results.get_results(copy=False) # can overwrite obj with array feature_results = {} for ch_idx, ch_name in enumerate(self.ch_names): bispectrum = waveshape[ch_idx] for component in self.settings.components.get_enabled(): spectrum_ch = COMPONENT_DICT[component](bispectrum) for fb in self.settings.frequency_bands: range_ = (f_spectrum_range >= self.frequency_ranges_hz[fb][0]) & ( f_spectrum_range <= self.frequency_ranges_hz[fb][1] ) # waveshape.results.plot() data_bs = spectrum_ch[range_, range_] for bispectrum_feature in self.used_features: feature_results[ f"{ch_name}_Bispectrum_{component}_{bispectrum_feature}_{fb}" ] = FEATURE_DICT[bispectrum_feature](data_bs) if self.settings.compute_features_for_whole_fband_range: feature_results[ f"{ch_name}_Bispectrum_{component}_{bispectrum_feature}_whole_fband_range" ] = FEATURE_DICT[bispectrum_feature](spectrum_ch) return feature_results