"""Module for handling settings."""
from pathlib import Path
from typing import ClassVar
from pydantic import Field, model_validator
from py_neuromodulation import PYNM_DIR, logger, user_features
from py_neuromodulation.utils.types import (
BoolSelector,
FrequencyRange,
PreprocessorName,
_PathLike,
NMBaseModel,
NormMethod,
)
from py_neuromodulation.processing.filter_preprocessing import FilterSettings
from py_neuromodulation.processing.normalization import NormalizationSettings
from py_neuromodulation.processing.resample import ResamplerSettings
from py_neuromodulation.processing.projection import ProjectionSettings
from py_neuromodulation.filter import KalmanSettings
from py_neuromodulation.features import BispectraSettings
from py_neuromodulation.features import NoldsSettings
from py_neuromodulation.features import MNEConnectivitySettings
from py_neuromodulation.features import FooofSettings
from py_neuromodulation.features import CoherenceSettings
from py_neuromodulation.features import SharpwaveSettings
from py_neuromodulation.features import OscillatorySettings, BandPowerSettings
from py_neuromodulation.features import BurstsSettings
class FeatureSelection(BoolSelector):
    """On/off switches for the feature estimators.

    Each field is a boolean flag named after a feature; ``True`` selects it.
    ``NMSettings`` stores an instance of this class as its ``features`` field,
    sets an additional ``True`` flag on it for every user-registered feature
    in ``NMSettings.__init__``, and its validator rejects a selection in which
    no flag is enabled.
    """

    raw_hjorth: bool = True
    return_raw: bool = True
    bandpass_filter: bool = False
    stft: bool = False
    fft: bool = True
    welch: bool = True
    sharpwave_analysis: bool = True
    fooof: bool = False
    nolds: bool = False
    coherence: bool = False
    bursts: bool = True
    linelength: bool = True
    mne_connectivity: bool = False
    bispectrum: bool = False
class PostprocessingSettings(BoolSelector):
    """On/off switches for the postprocessing steps.

    ``NMSettings`` stores an instance of this class as its ``postprocessing``
    field. The parameters for each step are presumably the ``NMSettings``
    fields with the same name plus a ``_settings`` suffix
    (``feature_normalization_settings``, ``project_cortex_settings``,
    ``project_subcortex_settings``) — TODO confirm against the consumers.
    """

    feature_normalization: bool = True
    project_cortex: bool = False
    project_subcortex: bool = False
[docs]
class NMSettings(NMBaseModel):
# Class variable to store instances
_instances: ClassVar[list["NMSettings"]] = []
# General settings
sampling_rate_features_hz: float = Field(default=10, gt=0)
segment_length_features_ms: float = Field(default=1000, gt=0)
frequency_ranges_hz: dict[str, FrequencyRange] = {
"theta": FrequencyRange(4, 8),
"alpha": FrequencyRange(8, 12),
"low_beta": FrequencyRange(13, 20),
"high_beta": FrequencyRange(20, 35),
"low_gamma": FrequencyRange(60, 80),
"high_gamma": FrequencyRange(90, 200),
"HFA": FrequencyRange(200, 400),
}
# Preproceessing settings
preprocessing: list[PreprocessorName] = [
"raw_resampling",
"notch_filter",
"re_referencing",
]
raw_resampling_settings: ResamplerSettings = ResamplerSettings()
preprocessing_filter: FilterSettings = FilterSettings()
raw_normalization_settings: NormalizationSettings = NormalizationSettings()
# Postprocessing settings
postprocessing: PostprocessingSettings = PostprocessingSettings()
feature_normalization_settings: NormalizationSettings = NormalizationSettings()
project_cortex_settings: ProjectionSettings = ProjectionSettings(max_dist_mm=20)
project_subcortex_settings: ProjectionSettings = ProjectionSettings(max_dist_mm=5)
# Feature settings
features: FeatureSelection = FeatureSelection()
fft_settings: OscillatorySettings = OscillatorySettings()
welch_settings: OscillatorySettings = OscillatorySettings()
stft_settings: OscillatorySettings = OscillatorySettings()
bandpass_filter_settings: BandPowerSettings = BandPowerSettings()
kalman_filter_settings: KalmanSettings = KalmanSettings()
burst_settings: BurstsSettings = BurstsSettings()
sharpwave_analysis_settings: SharpwaveSettings = SharpwaveSettings()
mne_connectivity_settings: MNEConnectivitySettings = MNEConnectivitySettings()
coherence_settings: CoherenceSettings = CoherenceSettings()
fooof_settings: FooofSettings = FooofSettings()
nolds_settings: NoldsSettings = NoldsSettings()
bispectrum_settings: BispectraSettings = BispectraSettings()
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
for feat_name in user_features.keys():
setattr(self.features, feat_name, True)
NMSettings._add_instance(self)
@classmethod
def _add_instance(cls, instance: "NMSettings") -> None:
"""Keep track of all instances created in class variable"""
cls._instances.append(instance)
@classmethod
def _add_feature(cls, feature: str) -> None:
for instance in cls._instances:
setattr(instance.features, feature, True)
@classmethod
def _remove_feature(cls, feature: str) -> None:
for instance in cls._instances:
delattr(instance.features, feature)
@model_validator(mode="after")
def validate_settings(self):
if len(self.features.get_enabled()) == 0:
raise ValueError("At least one feature must be selected.")
# Replace spaces with underscores in frequency band names
self.frequency_ranges_hz = {
k.replace(" ", "_"): v for k, v in self.frequency_ranges_hz.items()
}
if self.features.bandpass_filter:
# Check BandPass settings frequency bands
self.bandpass_filter_settings.validate_fbands(self)
# Check Kalman filter frequency bands
if self.bandpass_filter_settings.kalman_filter:
self.kalman_filter_settings.validate_fbands(self)
for k, v in self.frequency_ranges_hz.items():
if not isinstance(v, FrequencyRange):
self.frequency_ranges_hz[k] = FrequencyRange.create_from(v)
return self
def reset(self) -> "NMSettings":
self.features.disable_all()
self.preprocessing = []
self.postprocessing.disable_all()
return self
def set_fast_compute(self) -> "NMSettings":
self.reset()
self.features.fft = True
self.preprocessing = [
"raw_resampling",
"notch_filter",
"re_referencing",
]
self.postprocessing.feature_normalization = True
self.postprocessing.project_cortex = False
self.postprocessing.project_subcortex = False
return self
def enable_all_features(self):
self.features.enable_all()
return self
def disable_all_features(self):
self.features.disable_all()
return self
@staticmethod
def get_fast_compute() -> "NMSettings":
return NMSettings.get_default().set_fast_compute()
@classmethod
def load(cls, settings: "NMSettings | _PathLike | None") -> "NMSettings":
if isinstance(settings, cls):
return settings.validate()
if settings is None:
return cls.get_default()
return cls.from_file(str(settings))
[docs]
@staticmethod
def from_file(PATH: _PathLike) -> "NMSettings":
"""Load settings from file or a directory.
Args:
PATH (_PathLike): Path to settings file or to directory containing settings file,
or path to experiment including experiment prefix
(e.g. /path/to/exp/exp_prefix[_SETTINGS.json]).
Supported file types are .json and .yaml
Raises:
ValueError: when file format is not supported.
Returns:
NMSettings: PyNM settings object
"""
path = Path(PATH)
# If directory is passed, look for settings file inside
if path.is_dir():
for child in path.iterdir():
if child.is_file() and child.suffix in [".json", ".yaml"]:
path = child
break
# If prefix is passed, look for settings file matching prefix
if not path.is_dir() and not path.is_file():
for child in path.parent.iterdir():
ext = child.suffix.lower()
if (
child.is_file()
and ext in [".json", ".yaml"]
and child.name == path.stem + "_SETTINGS" + ext
):
path = child
break
match path.suffix:
case ".json":
import json
with open(path) as f:
model_dict = json.load(f)
case ".yaml":
import yaml
# with open(path) as f:
# model_dict = yaml.safe_load(f)
# Timon: this is potentially dangerous since python code is directly executed
with open(path) as f:
model_dict = yaml.load(f, Loader=yaml.Loader)
case _:
raise ValueError("File format not supported.")
return NMSettings(**model_dict)
@staticmethod
def get_default() -> "NMSettings":
return NMSettings.from_file(PYNM_DIR / "default_settings.yaml")
@staticmethod
def list_normalization_methods() -> list[NormMethod]:
return NormalizationSettings.list_normalization_methods()
def save(
self, out_dir: _PathLike = ".", prefix: str = "", format: str = "yaml"
) -> None:
filename = f"{prefix}_SETTINGS.{format}" if prefix else f"SETTINGS.{format}"
path_out = Path(out_dir) / filename
with open(path_out, "w") as f:
match format:
case "json":
f.write(self.model_dump_json(indent=4))
case "yaml":
import yaml
yaml.dump(self.model_dump(), f, default_flow_style=None)
logger.info(f"Settings saved to {path_out.resolve()}")
# Module-level aliases kept so code written for earlier PyNM versions keeps
# working; new code can call the NMSettings methods directly.
def get_default_settings() -> NMSettings:
    """Return the default settings (alias of ``NMSettings.get_default``)."""
    default_settings = NMSettings.get_default()
    return default_settings
def reset_settings(settings: NMSettings) -> NMSettings:
    """Reset ``settings`` via its ``reset`` method and return the result."""
    cleared = settings.reset()
    return cleared
def get_fast_compute() -> NMSettings:
    """Return fast-compute settings (alias of ``NMSettings.get_fast_compute``)."""
    fast_settings = NMSettings.get_fast_compute()
    return fast_settings
def test_settings(settings: NMSettings) -> NMSettings:
    """Validate ``settings`` and return whatever ``settings.validate()`` yields."""
    validated = settings.validate()
    return validated