Source code for features.coherence

import numpy as np
from collections.abc import Iterable


from typing import TYPE_CHECKING, Annotated
from pydantic import Field, field_validator

from py_neuromodulation.utils.types import (
    NMFeature,
    BoolSelector,
    FrequencyRange,
    NMBaseModel,
)
from py_neuromodulation import logger

if TYPE_CHECKING:
    from py_neuromodulation import NMSettings


class CoherenceMethods(BoolSelector):
    """Boolean switches selecting which coherence measures are computed.

    The flags are read by ``Coherence.__init__`` and forwarded to each
    ``CoherenceObject``; ``Coherence.test_settings`` logs a warning when
    ``get_enabled()`` reports that none is selected.
    """

    # Enables coh_val = |Pxy|**2 / (Pxx * Pyy) in CoherenceObject.get_coh.
    coh: bool = True
    # Enables icoh_val = imag(Pxy) / sqrt(Pxx * Pyy) in CoherenceObject.get_coh.
    icoh: bool = True


class CoherenceFeatures(BoolSelector):
    """Boolean switches selecting which summary features are extracted.

    The flags are consulted in ``CoherenceObject.get_coh`` for every enabled
    coherence measure.
    """

    # Mean of the measure over the frequency bins inside each band.
    mean_fband: bool = True
    # Maximum of the measure over the frequency bins inside each band.
    max_fband: bool = True
    # Frequency (taken from the Welch frequency axis) at which the measure
    # peaks over the whole spectrum.
    max_allfbands: bool = True
    

# A list of strings constrained to exactly two items (min_length == max_length
# == 2); used below as the element type of `CoherenceSettings.channels`.
ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)]


class CoherenceSettings(NMBaseModel):
    """Settings for the coherence feature.

    Read by ``Coherence`` as ``settings.coherence_settings``.
    """

    # Which summary features to extract per coherence measure.
    features: CoherenceFeatures = CoherenceFeatures()
    # Which coherence measures to compute.
    method: CoherenceMethods = CoherenceMethods()
    # Channel pairs; each entry is a two-element list of channel names.
    # `Coherence` matches each name against the available channel names by
    # prefix (`str.startswith`) and requires exactly one match.
    channels: list[ListOfTwoStr] = []
    # Forwarded as `nperseg` to scipy.signal.welch / scipy.signal.csd.
    # NOTE(review): `ge=0` admits 0, which presumably is not a usable segment
    # length for Welch's method — confirm whether `gt=0` was intended.
    nperseg: int = Field(default=128, ge=0)
    # Names of the bands to evaluate; `Coherence` looks them up as keys in
    # `settings.frequency_ranges_hz`. At least one name is required.
    frequency_bands: list[str] = Field(default=["high_beta"], min_length=1)

    @field_validator("frequency_bands")
    def fbands_spaces_to_underscores(cls, frequency_bands):
        """Return the band names with every space replaced by an underscore."""
        return [f.replace(" ", "_") for f in frequency_bands]


class CoherenceObject:
    """Coherence feature calculator for one pair of channels.

    The constructor only stores its configuration. Each call to `get_coh`
    recomputes the spectra and keeps the latest results on the instance
    (`f`, `Pxx`, `Pyy`, `Pxy`, `coh_val`, `icoh_val`).
    """

    def __init__(
        self,
        sfreq: float,
        window: str,
        fbands: list[FrequencyRange],
        fband_names: list[str],
        nperseg: int,
        ch_1_name: str,
        ch_2_name: str,
        ch_1_idx: int,
        ch_2_idx: int,
        coh: bool,
        icoh: bool,
        features_coh: CoherenceFeatures,
    ) -> None:
        """Store the configuration for one channel pair.

        Args:
            sfreq: Sampling frequency passed to the spectral estimators.
            window: Window specification passed to the spectral estimators.
            fbands: Frequency bands; each item is indexed as
                ``(lower, upper)`` via ``fband[0]`` / ``fband[1]``.
            fband_names: Band names, parallel to `fbands`, used in feature keys.
            nperseg: Segment length passed to the spectral estimators.
            ch_1_name: Name of the first channel, used in feature keys.
            ch_2_name: Name of the second channel, used in feature keys.
            ch_1_idx: Row index of the first channel (not used inside this
                class; read by the caller to slice the data).
            ch_2_idx: Row index of the second channel (see `ch_1_idx`).
            coh: Whether to compute ``|Pxy|**2 / (Pxx * Pyy)``.
            icoh: Whether to compute ``imag(Pxy) / sqrt(Pxx * Pyy)``.
            features_coh: Flags (`mean_fband`, `max_fband`, `max_allfbands`)
                selecting which features are written.
        """
        self.sfreq = sfreq
        self.window = window
        self.fbands = fbands
        self.fband_names = fband_names
        self.ch_1 = ch_1_name
        self.ch_2 = ch_2_name
        self.ch_1_idx = ch_1_idx
        self.ch_2_idx = ch_2_idx
        self.nperseg = nperseg
        self.coh = coh
        self.icoh = icoh
        self.features_coh = features_coh

        # Results of the most recent `get_coh` call.
        self.Pxx = None
        self.Pyy = None
        self.Pxy = None
        self.f = None
        self.coh_val = None
        self.icoh_val = None

    def get_coh(self, feature_results, x, y):
        """Compute the enabled coherence features for signals `x` and `y`.

        Args:
            feature_results: Mapping that receives the features; it is
                modified in place and also returned.
            x: Samples of the first channel.
            y: Samples of the second channel.

        Returns:
            `feature_results`, extended with keys of the form
            ``{coh|icoh}_{ch_1}_to_{ch_2}_{feature}_{band_name}``. If neither
            `coh` nor `icoh` is enabled, it is returned unchanged.
        """
        from scipy.signal import welch, csd

        self.f, self.Pxx = welch(x, self.sfreq, self.window, nperseg=self.nperseg)
        self.Pyy = welch(y, self.sfreq, self.window, nperseg=self.nperseg)[1]
        self.Pxy = csd(x, y, self.sfreq, self.window, nperseg=self.nperseg)[1]

        # Only enabled measures are collected, so a disabled measure is never
        # touched in the feature loop below.
        measures = []
        if self.coh:
            # Squared magnitude; this is why the values differ from
            # abs(Sxy) / sqrt(Sxx * Syy), which is its square root.
            self.coh_val = np.abs(self.Pxy) ** 2 / (self.Pxx * self.Pyy)
            measures.append(("coh", self.coh_val))
        if self.icoh:
            self.icoh_val = self.Pxy.imag / np.sqrt(self.Pxx * self.Pyy)
            measures.append(("icoh", self.icoh_val))

        if not measures:
            return feature_results

        # The masks depend only on the frequency axis, so build them once.
        # Band edges are exclusive on both sides.
        band_masks = [
            np.bitwise_and(self.f > fband[0], self.f < fband[1])
            for fband in self.fbands
        ]

        for coh_name, coh_val in measures:
            key_prefix = [coh_name, self.ch_1, "to", self.ch_2]

            for idx, band_mask in enumerate(band_masks):
                if self.features_coh.mean_fband:
                    feature_name = "_".join(
                        [*key_prefix, "mean_fband", self.fband_names[idx]]
                    )
                    feature_results[feature_name] = np.mean(coh_val[band_mask])
                if self.features_coh.max_fband:
                    feature_name = "_".join(
                        [*key_prefix, "max_fband", self.fband_names[idx]]
                    )
                    feature_results[feature_name] = np.max(coh_val[band_mask])

            if self.features_coh.max_allfbands:
                key_parts = [*key_prefix, "max_allfbands"]
                # The key has always carried the name of the last band, even
                # though the value is the peak frequency of the whole
                # spectrum; kept so existing feature names do not change.
                if band_masks:
                    key_parts.append(self.fband_names[len(band_masks) - 1])
                feature_results["_".join(key_parts)] = self.f[np.argmax(coh_val)]

        return feature_results


class Coherence(NMFeature):
    """Coherence features between configured pairs of channels.

    One `CoherenceObject` is created per entry of
    ``settings.coherence_settings.channels``; `calc_feature` runs all of them
    on a data batch and merges their outputs into one dict.
    """

    def __init__(
        self, settings: "NMSettings", ch_names: list[str], sfreq: float
    ) -> None:
        """Validate the settings and build one `CoherenceObject` per pair.

        Args:
            settings: Settings object; ``coherence_settings`` and
                ``frequency_ranges_hz`` are read from it.
            ch_names: Names of the available channels. Configured channel
                names are matched against these by prefix.
            sfreq: Sampling frequency.

        Raises:
            RuntimeError: A configured channel matches no name, or more than
                one name, in `ch_names`.
            AssertionError: A configured band is unknown or reaches the
                Nyquist frequency.
        """
        self.settings = settings.coherence_settings
        self.frequency_ranges_hz = settings.frequency_ranges_hz
        self.sfreq = sfreq
        self.ch_names = ch_names
        self.coherence_objects: list[CoherenceObject] = []

        self.test_settings(settings, ch_names, sfreq)

        fband_names = self.settings.frequency_bands

        for ch_pair in self.settings.channels:
            ch_1_name, ch_2_name = ch_pair[0], ch_pair[1]
            # A fresh list per object, so no two objects share band specs.
            fband_specs = [
                self.frequency_ranges_hz[band_name] for band_name in fband_names
            ]

            self.coherence_objects.append(
                CoherenceObject(
                    sfreq,
                    "hann",
                    fband_specs,
                    fband_names,
                    self.settings.nperseg,
                    ch_1_name,
                    ch_2_name,
                    self._channel_index(self.ch_names, ch_1_name),
                    self._channel_index(self.ch_names, ch_2_name),
                    self.settings.method.coh,
                    self.settings.method.icoh,
                    self.settings.features,
                )
            )

    @staticmethod
    def _channel_index(ch_names: list[str], prefix: str) -> int:
        """Return the index of the first name in `ch_names` starting with `prefix`."""
        return next(
            idx for idx, ch in enumerate(ch_names) if ch.startswith(prefix)
        )

    @staticmethod
    def test_settings(
        settings: "NMSettings",
        ch_names: Iterable[str],
        sfreq: float,
    ):
        """Check the coherence settings against the channels and sampling rate.

        Args:
            settings: Settings object; ``coherence_settings`` and
                ``frequency_ranges_hz`` are read from it.
            ch_names: Names of the available channels.
            sfreq: Sampling frequency.

        Raises:
            RuntimeError: A configured channel matches no name, or more than
                one name, in `ch_names` (prefix match).
            AssertionError: A configured band is missing from
                ``settings.frequency_ranges_hz``, or one of its edges is not
                below ``sfreq / 2``.
        """
        coh_settings = settings.coherence_settings

        flat_channels = [ch for ch_pair in coh_settings.channels for ch in ch_pair]
        # Number of available channel names each configured name is a prefix of.
        valid_coh_channel = [
            sum(ch.startswith(ch_coh) for ch in ch_names) for ch_coh in flat_channels
        ]

        for ch_coh, n_matches in zip(flat_channels, valid_coh_channel):
            if n_matches == 0:
                raise RuntimeError(
                    f"Coherence selected channel {ch_coh} does not match any channel name: \n"
                    f" - settings.coherence_settings.channels: {coh_settings.channels}\n"
                    f" - ch_names: {ch_names} \n"
                )
            if n_matches > 1:
                raise RuntimeError(
                    f"Coherence selected channel {ch_coh} is ambigous and matches more than one channel name: \n"
                    f" - settings.coherence_settings.channels: {coh_settings.channels}\n"
                    f" - ch_names: {ch_names} \n"
                )

        # AssertionError is kept so callers' exception handling is unchanged,
        # but it is raised explicitly so the checks survive `python -O`.
        if not all(
            f_band_coh in settings.frequency_ranges_hz
            for f_band_coh in coh_settings.frequency_bands
        ):
            raise AssertionError(
                "coherence selected frequency bands don't match the ones "
                "specified in s['frequency_ranges_hz']\n"
                f" - coherence frequency bands: {coh_settings.frequency_bands}\n"
                f" - specified frequency_ranges_hz: {settings.frequency_ranges_hz}"
            )

        if not all(
            settings.frequency_ranges_hz[fb][0] < sfreq / 2
            and settings.frequency_ranges_hz[fb][1] < sfreq / 2
            for fb in coh_settings.frequency_bands
        ):
            selected_ranges = {
                fb: settings.frequency_ranges_hz[fb]
                for fb in coh_settings.frequency_bands
            }
            raise AssertionError(
                "the coherence frequency band ranges need to be smaller than "
                "the Nyquist frequency\n"
                f" - got sfreq = {sfreq} and fband ranges {selected_ranges}"
            )

        if not coh_settings.method.get_enabled():
            logger.warning(
                "feature coherence enabled, but no coherence['method'] selected"
            )

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute all configured coherence features for one data batch.

        Args:
            data: Array indexed as ``data[channel_idx, :]``; rows are
                selected with the indices resolved in `__init__`.

        Returns:
            Dict mapping feature names to values, merged over all pairs.
        """
        feature_results = {}

        for coh_obj in self.coherence_objects:
            feature_results = coh_obj.get_coh(
                feature_results,
                data[coh_obj.ch_1_idx, :],
                data[coh_obj.ch_2_idx, :],
            )

        return feature_results