Source code for filter.notch_filter

import numpy as np
from typing import cast

from py_neuromodulation.utils.types import NMPreprocessor
from py_neuromodulation import logger


class NotchFilter(NMPreprocessor):
    """FIR notch (band-stop) filter built once and applied per data batch.

    The filter coefficients are computed a single time in ``__init__`` with
    ``mne.filter.create_filter`` and stored in ``self.filter_bank``;
    ``process`` then only has to apply them. If no notch frequency remains
    below ``sfreq / 2``, ``self.filter_bank`` is ``None`` and ``process``
    returns its input unchanged.
    """

    def __init__(
        self,
        sfreq: float,
        line_noise: float | None = None,
        freqs: np.ndarray | None = None,
        notch_widths: int | np.ndarray | None = 3,
        trans_bandwidth: float = 6.8,
    ) -> None:
        """Design the notch filter.

        Parameters
        ----------
        sfreq : float
            Sampling frequency, forwarded to ``create_filter``.
        line_noise : float | None
            Base frequency to remove. Only used when ``freqs`` is None, in
            which case its integer multiples below ``sfreq / 2`` are notched.
        freqs : np.ndarray | None
            Explicit notch center frequencies. Takes precedence over
            ``line_noise``.
        notch_widths : int | np.ndarray | None
            Width of each stop band. A scalar is broadcast to every frequency,
            an array must have the same length as ``freqs``, and None selects
            ``freqs / 200``.
        trans_bandwidth : float
            Total transition bandwidth; half of it is applied on each side of
            every stop band.

        Raises
        ------
        ValueError
            If neither ``line_noise`` nor ``freqs`` is given, if any notch
            width is negative, or if ``notch_widths`` has a length that is
            neither 1 nor ``len(freqs)``.
        """
        # Imported lazily, as in the original, so that importing this module
        # does not require loading mne.filter.
        from mne.filter import create_filter

        if line_noise is None and freqs is None:
            raise ValueError(
                "Either line_noise or freqs must be defined if notch_filter is"
                " activated."
            )

        if freqs is None:
            # Fundamental plus harmonics, strictly below sfreq / 2.
            freqs = np.arange(line_noise, sfreq / 2, line_noise, dtype=int)
        else:
            # Accept lists / scalars as well; an ndarray passes through as is.
            freqs = np.atleast_1d(freqs)

        # Drop a trailing frequency that sits at or above sfreq / 2.
        if freqs.size > 0 and freqs[-1] >= sfreq / 2:
            freqs = freqs[:-1]

        # Code is copied from filter.py notch_filter
        if freqs.size == 0:
            self.filter_bank = None
            logger.warning(
                "WARNING: notch_filter is activated but data is not being"
                " filtered. This may be due to a low sampling frequency or"
                " incorrect specifications. Make sure your settings are"
                f" correct. Got: {sfreq = }, {line_noise = }, {freqs = }."
            )
            return

        filter_length = int(sfreq - 1)

        if notch_widths is None:
            notch_widths = freqs / 200.0
        else:
            # Normalize first so the sign check also works for plain lists.
            notch_widths = np.atleast_1d(notch_widths)
            if np.any(notch_widths < 0):
                raise ValueError("notch_widths must be >= 0")
            if len(notch_widths) == 1:
                notch_widths = notch_widths[0] * np.ones_like(freqs)
            elif len(notch_widths) != len(freqs):
                raise ValueError(
                    "notch_widths must be None, scalar, or the "
                    "same length as freqs"
                )

        notch_widths = cast(np.ndarray, notch_widths)  # For MyPy only, no runtime cost

        # Speed this up by computing the fourier coefficients once
        tb_half = trans_bandwidth / 2.0
        lows = [freq - nw / 2.0 - tb_half for freq, nw in zip(freqs, notch_widths)]
        highs = [freq + nw / 2.0 + tb_half for freq, nw in zip(freqs, notch_widths)]

        # Passing the upper edges as l_freq and the lower edges as h_freq
        # (l_freq > h_freq) is how mne's own notch_filter requests band-stop
        # filters from create_filter.
        self.filter_bank = create_filter(
            data=None,
            sfreq=sfreq,
            l_freq=highs,
            h_freq=lows,
            filter_length=filter_length,  # type: ignore
            l_trans_bandwidth=tb_half,  # type: ignore
            h_trans_bandwidth=tb_half,  # type: ignore
            method="fir",
            iir_params=None,
            phase="zero",
            fir_window="hamming",
            fir_design="firwin",
            verbose=False,
        )

    def process(self, data: np.ndarray) -> np.ndarray:
        """Apply the precomputed notch filter to ``data``.

        Parameters
        ----------
        data : np.ndarray
            Signal to filter. NOTE(review): presumably shaped
            (n_channels, n_samples) as mne expects - confirm against callers.

        Returns
        -------
        np.ndarray
            The filtered data, or ``data`` itself when no filter was built
            (``self.filter_bank is None``).
        """
        if self.filter_bank is None:
            return data

        from mne.filter import _overlap_add_filter

        return _overlap_add_filter(
            x=data,
            h=self.filter_bank,
            n_fft=None,
            phase="zero",
            picks=None,
            n_jobs=1,
            copy=True,
            pad="reflect_limited",
        )