Source code for features.bandpower

import numpy as np
from collections.abc import Sequence
from typing import TYPE_CHECKING
from pydantic import field_validator
from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
from py_neuromodulation.utils.pydantic_extensions import (
    NMField,
    NMErrorList,
    create_validation_error,
)
from py_neuromodulation import logger

if TYPE_CHECKING:
    from py_neuromodulation.stream.settings import NMSettings
    from py_neuromodulation.filter import KalmanSettings


class BandpowerFeatures(BoolSelector):
    """Toggle which Hjorth parameters BandPower computes for every channel/band pair.

    NOTE(review): presumably ``BoolSelector.get_enabled()`` returns the names of
    the fields set to True in declaration order. That order then fixes the order
    of the generated features, so verify this before reordering the fields below.
    """

    # Hjorth activity: variance of the band-filtered segment (BandPower may
    # additionally log10-transform and Kalman-filter it).
    activity: bool = True
    # Hjorth mobility: sqrt(var(diff(x)) / var(x)).
    mobility: bool = False
    # Hjorth complexity: mobility of the first derivative divided by the
    # mobility of the signal itself.
    complexity: bool = False


class BandPowerSettings(NMBaseModel):
    """Settings for the BandPower feature (``settings.bandpass_filter_settings``)."""

    # Analysis window in milliseconds per frequency band. The keys must cover
    # every band in ``settings.frequency_ranges_hz`` (checked by validate_fbands).
    segment_lengths_ms: dict[str, int] = NMField(
        default={
            "theta": 1000,
            "alpha": 500,
            "low beta": 333,
            "high beta": 333,
            "low gamma": 100,
            "high gamma": 100,
            "HFA": 100,
        },
        custom_metadata={"field_type": "FrequencySegmentLength"},
    )
    # Which Hjorth parameters to compute; at least one must be enabled.
    bandpower_features: BandpowerFeatures = BandpowerFeatures()
    # Apply log10 to the "activity" feature.
    log_transform: bool = True
    # Kalman-filter the "activity" feature. BandPower's ``use_kf`` argument can
    # override this flag.
    kalman_filter: bool = False

    @field_validator("bandpower_features")
    @classmethod
    def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures):
        """Ensure that at least one bandpower feature is enabled.

        :param bandpower_features: The feature selection being validated.
        :return: ``bandpower_features`` unchanged.
        :raises: The error built by ``create_validation_error`` when no feature
            is enabled.
        """
        if not bandpower_features.get_enabled():
            raise create_validation_error(
                error_message="Set at least one bandpower_feature to True.",
                location=["bandpass_filter_settings", "bandpower_features"],
            )
        return bandpower_features

    def validate_fbands(self, settings: "NMSettings") -> NMErrorList:
        """Cross-check the per-band segment lengths against the global settings.

        The method performs these checks:

        * A band in ``segment_lengths_ms`` that is missing from
          ``settings.frequency_ranges_hz`` only logs a warning.
        * Every segment length must be positive (error).
        * Every segment length must not exceed
          ``settings.segment_length_features_ms`` (error).
        * Every band in ``settings.frequency_ranges_hz`` must have an entry in
          ``segment_lengths_ms`` (error).

        :param settings: Global settings providing ``frequency_ranges_hz`` and
            ``segment_length_features_ms``.
        :type settings: NMSettings
        :return: The collected validation errors (empty if all checks pass).
            Errors are returned, not raised; the caller decides how to report them.
        """
        errors: NMErrorList = NMErrorList()

        for fband_name, seg_length_fband in self.segment_lengths_ms.items():
            # Unknown bands are tolerated (they are simply unused) but reported.
            if fband_name not in settings.frequency_ranges_hz:
                logger.warning(
                    f"Frequency band {fband_name} in bandpass_filter_settings.segment_lengths_ms"
                    " is not defined in settings.frequency_ranges_hz"
                )

            # A non-positive length makes BandPower slice the wrong part of the
            # buffer (``x[-0:]`` is the whole array), so reject it explicitly.
            if seg_length_fband <= 0:
                errors.add_error(
                    f"segment length {seg_length_fband} of frequency band {fband_name} "
                    "needs to be positive",
                    location=[
                        "bandpass_filter_settings",
                        "segment_lengths_ms",
                        fband_name,
                    ],
                )

            # A band's window cannot be longer than the data window passed to features.
            if seg_length_fband > settings.segment_length_features_ms:
                errors.add_error(
                    f"segment length {seg_length_fband} needs to be smaller than or equal to "
                    f"settings['segment_length_features_ms'] = {settings.segment_length_features_ms}",
                    location=[
                        "bandpass_filter_settings",
                        "segment_lengths_ms",
                        fband_name,
                    ],
                )

        # Every band that will be filtered needs a segment length.
        for fband_name in settings.frequency_ranges_hz:
            if fband_name not in self.segment_lengths_ms:
                errors.add_error(
                    f"frequency range {fband_name} "
                    "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms",
                    location=[
                        "bandpass_filter_settings",
                        "segment_lengths_ms",
                        fband_name,
                    ],
                )

        return errors


[docs] class BandPower(NMFeature): def __init__( self, settings: "NMSettings", ch_names: Sequence[str], sfreq: float, use_kf: bool | None = None, ) -> None: settings.validate() self.bp_settings: BandPowerSettings = settings.bandpass_filter_settings self.kalman_filter_settings: KalmanSettings = settings.kalman_filter_settings self.sfreq = sfreq self.ch_names = ch_names self.KF_dict: dict = {} from py_neuromodulation.filter import MNEFilter self.bandpass_filter = MNEFilter( f_ranges=[ tuple(frange) for frange in settings.frequency_ranges_hz.values() ], sfreq=self.sfreq, filter_length=self.sfreq - 1, verbose=False, ) if use_kf or (use_kf is None and self.bp_settings.kalman_filter): self.init_KF("bandpass_activity") seglengths = self.bp_settings.segment_lengths_ms self.feature_params = [] for ch_idx, ch_name in enumerate(self.ch_names): for f_band_idx, f_band in enumerate(settings.frequency_ranges_hz.keys()): seglength_ms = seglengths[f_band] seglen = int(np.floor(self.sfreq / 1000 * seglength_ms)) for bp_feature in self.bp_settings.bandpower_features.get_enabled(): feature_name = "_".join([ch_name, "bandpass", bp_feature, f_band]) self.feature_params.append( ( ch_idx, f_band_idx, seglen, bp_feature, feature_name, ) ) def init_KF(self, feature: str) -> None: from py_neuromodulation.filter import define_KF for f_band in self.kalman_filter_settings.frequency_bands: for channel in self.ch_names: self.KF_dict["_".join([channel, feature, f_band])] = define_KF( self.kalman_filter_settings.Tp, self.kalman_filter_settings.sigma_w, self.kalman_filter_settings.sigma_v, ) def update_KF(self, feature_calc: np.floating, KF_name: str) -> np.floating: if KF_name in self.KF_dict: self.KF_dict[KF_name].predict() self.KF_dict[KF_name].update(feature_calc) feature_calc = self.KF_dict[KF_name].x[0] return feature_calc
[docs] def calc_feature(self, data: np.ndarray) -> dict: data = self.bandpass_filter.filter_data(data) feature_results = {} for ( ch_idx, f_band_idx, seglen, bp_feature, feature_name, ) in self.feature_params: feature_results[feature_name] = self.calc_bp_feature( bp_feature, feature_name, data[ch_idx, f_band_idx, -seglen:] ) return feature_results
def calc_bp_feature(self, bp_feature, feature_name, data): match bp_feature: case "activity": feature_calc = np.var(data) if self.bp_settings.log_transform: feature_calc = np.log10(feature_calc) if self.KF_dict: feature_calc = self.update_KF(feature_calc, feature_name) case "mobility": feature_calc = np.sqrt(np.var(np.diff(data)) / np.var(data)) case "complexity": feature_calc = self.calc_complexity(data) case _: raise ValueError(f"Unknown bandpower feature: {bp_feature}") return np.nan_to_num(feature_calc) @staticmethod def calc_complexity(data: np.ndarray) -> float: dat_deriv = np.diff(data) deriv_variance = np.var(dat_deriv) mobility = np.sqrt(deriv_variance / np.var(data)) dat_deriv_2_var = np.var(np.diff(dat_deriv)) deriv_mobility = np.sqrt(dat_deriv_2_var / deriv_variance) return deriv_mobility / mobility