from collections.abc import Iterable
import numpy as np
from typing import TYPE_CHECKING, Annotated, Literal
from pydantic import Field
from py_neuromodulation.utils.types import NMFeature, NMBaseModel
from py_neuromodulation.utils.pydantic_extensions import NMField
if TYPE_CHECKING:
from py_neuromodulation import NMSettings
ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)]
MNE_CONNECTIVITY_METHOD = Literal[
"coh",
"cohy",
"imcoh",
"cacoh",
"mic",
"mim",
"plv",
"ciplv",
"ppc",
"pli",
"dpli",
"wpli",
"wpli2_debiased",
"gc",
"gc_tr",
]
MNE_CONNECTIVITY_MODE = Literal["multitaper", "fourier", "cwt_morlet"]
class MNEConnectivitySettings(NMBaseModel):
method: MNE_CONNECTIVITY_METHOD = NMField(default="plv")
mode: MNE_CONNECTIVITY_MODE = NMField(default="multitaper")
channels: list[ListOfTwoStr] = []
[docs]
class MNEConnectivity(NMFeature):
def __init__(
self,
settings: "NMSettings",
ch_names: Iterable[str],
sfreq: float,
) -> None:
self.settings = settings
self.ch_names = ch_names
self.sfreq = sfreq
self.channels = settings.mne_connectivity_settings.channels
# Params used by spectral_connectivity_epochs
self.mode = settings.mne_connectivity_settings.mode
self.method = settings.mne_connectivity_settings.method
self.indices = ([], []) # convert channel names to channel indices in data
for con_idx in range(len(self.channels)):
seed_name = self.channels[con_idx][0]
target_name = self.channels[con_idx][1]
seed_name_reref = [ch for ch in self.ch_names if ch.startswith(seed_name)][0]
target_name_reref = [ch for ch in self.ch_names if ch.startswith(target_name)][0]
self.indices[0].append(self.ch_names.index(seed_name_reref))
self.indices[1].append(self.ch_names.index(target_name_reref))
self.fbands = settings.frequency_ranges_hz
self.fband_ranges: list = []
self.result_keys = []
self.prev_batch_shape: tuple = (-1, -1) # sentinel value
[docs]
def calc_feature(self, data: np.ndarray) -> dict:
from mne_connectivity import spectral_connectivity_epochs
# n_jobs is here kept to 1, since setup of the multiprocessing Pool
# takes longer than most batch computing sizes
spec_out = spectral_connectivity_epochs(
data=np.expand_dims(data, axis=0), # add singleton epoch dimension
sfreq=self.sfreq,
method=self.method,
mode=self.mode,
indices=self.indices,
verbose=False,
)
dat_conn: np.ndarray = spec_out.get_data()
# Get frequency band ranges only for the first batch, it's already the same
if len(self.fband_ranges) == 0:
for fband_range in self.fbands.values():
self.fband_ranges.append(
np.where(
(np.array(spec_out.freqs) >= fband_range[0])
& (np.array(spec_out.freqs) <= fband_range[1])
)[0]
)
feature_results = {}
for con_idx in np.arange(dat_conn.shape[0]):
for fband_idx, fband_name in enumerate(self.fbands):
# TODO: Add support for max_fband and max_allfbands
feature_results[
"_".join(
[
self.method,
self.channels[con_idx][0], # seed channel name
"to",
self.channels[con_idx][1], # target channel name
"mean_fband",
fband_name,
]
)
] = np.mean(dat_conn[con_idx, self.fband_ranges[fband_idx]])
# Store current experiment parameters to check if re-initialization is needed
self.prev_batch_shape = data.shape
return feature_results