import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import gridspec
import seaborn as sb
from pathlib import PurePath
from py_neuromodulation import logger, PYNM_DIR
from py_neuromodulation.utils.types import _PathLike
def plot_df_subjects(
    df,
    x_col="sub",
    y_col="performance_test",
    hue=None,
    title="channel specific performances",
    PATH_SAVE: _PathLike = "",
    figsize_tuple: tuple[float, float] = (5, 3),
):
    """Draw a translucent box plot with the individual points overlaid.

    Parameters
    ----------
    df : pd.DataFrame
        Table holding the columns named by ``x_col``, ``y_col`` and ``hue``.
    x_col : str, optional
        Column used for the categorical x axis.
    y_col : str, optional
        Column plotted on the y axis; also used as the y label.
    hue : str | None, optional
        Column used for colour grouping. A legend is only drawn when set.
    title : str, optional
        Figure title.
    PATH_SAVE : _PathLike, optional
        Where to save the figure; nothing is saved when empty.
    figsize_tuple : tuple[float, float], optional
        Figure size handed to ``plt.figure``.

    Returns
    -------
    The current matplotlib axes.
    """
    box_alpha = 0.4
    # Arguments common to the box layer and the point layer.
    shared_kwargs = {
        "x": x_col,
        "y": y_col,
        "hue": hue,
        "data": df,
        "palette": "viridis",
    }

    plt.figure(figsize=figsize_tuple, dpi=300)
    sb.boxplot(
        **shared_kwargs,
        showmeans=False,
        boxprops={"alpha": box_alpha},
        showcaps=True,
        showbox=True,
        showfliers=False,
        notch=False,
        whiskerprops={"linewidth": 2, "zorder": 10, "alpha": box_alpha},
        capprops={"alpha": box_alpha},
        medianprops={
            "linestyle": "-",
            "linewidth": 5,
            "color": "gray",
            "alpha": box_alpha,
        },
    )
    strip_ax = sb.stripplot(**shared_kwargs, dodge=True, s=5)

    if hue is not None:
        # Both layers register legend entries; keep only the first set so
        # every hue level is listed once.
        n_levels = df[hue].nunique()
        all_handles, all_labels = strip_ax.get_legend_handles_labels()
        plt.legend(
            all_handles[:n_levels],
            all_labels[:n_levels],
            bbox_to_anchor=(1.05, 1),
            loc=2,
            title=hue,
            borderaxespad=0.0,
        )

    plt.title(title)
    plt.ylabel(y_col)
    plt.xticks(rotation=90)

    if PATH_SAVE:
        plt.savefig(PATH_SAVE, bbox_inches="tight")

    return plt.gca()
def plot_epoch(
    X_epoch: np.ndarray,
    y_epoch: np.ndarray,
    feature_names: list,
    z_score: bool | None = None,
    epoch_len: int = 4,
    sfreq: int = 10,
    str_title: str = "",
    str_label: str = "",
    ytick_labelsize: float | None = None,
):
    """Plot a feature image above the corresponding target trace.

    Parameters
    ----------
    X_epoch : np.ndarray
        Feature data. When ``z_score`` is None it is squeezed, NaN-averaged
        over axis 0, NaN-filled, z-scored along axis 0 and transposed before
        being shown. Otherwise it is handed to ``imshow`` unchanged.
    y_epoch : np.ndarray
        Target values; wrapped into a leading axis of length one before
        plotting.
    feature_names : list
        Labels for the image rows.
    z_score : bool | None, optional
        NOTE(review): the preprocessing runs only when this is None; both
        True and False skip it. Confirm that this is the intended contract.
    epoch_len : int, optional
        Epoch duration in seconds; the x axis starts at ``-epoch_len / 2``.
    sfreq : int, optional
        Feature sampling rate in Hz, used to convert columns to seconds.
    str_title : str, optional
        Title of the feature image.
    str_label : str, optional
        Title of the target subplot.
    ytick_labelsize : float | None, optional
        Font size of the feature labels.
    """
    from scipy.stats import zscore

    if z_score is None:
        X_epoch = zscore(
            np.nan_to_num(np.nanmean(np.squeeze(X_epoch), axis=0)),
            axis=0,
            nan_policy="omit",
        ).T
    y_epoch = np.stack([np.array(y_epoch)])

    # Derive the labels from the tick positions so both always have the same
    # length. A float-step np.arange can come out one element long or short
    # and is unrelated to the actual image width.
    tick_positions = np.arange(X_epoch.shape[1])
    tick_labels = np.round(-epoch_len / 2 + tick_positions / sfreq, 2)

    plt.figure(figsize=(6, 6))

    plt.subplot(211)
    plt.imshow(X_epoch, aspect="auto")
    plt.yticks(
        np.arange(0, len(feature_names), 1), feature_names, size=ytick_labelsize
    )
    plt.xticks(tick_positions, tick_labels, rotation=90)
    plt.gca().invert_yaxis()
    plt.xlabel("Time [s]")
    plt.title(str_title)

    plt.subplot(212)
    for trace in y_epoch:
        plt.plot(trace, color="black", alpha=0.4)
    plt.plot(
        y_epoch.mean(axis=0),
        color="black",
        alpha=1,
        linewidth=3.0,
        label="mean target",
    )
    plt.legend()
    plt.ylabel("Target")
    plt.title(str_label)
    plt.xticks(tick_positions, tick_labels, rotation=90)
    plt.xlabel("Time [s]")

    plt.tight_layout()
def reg_plot(
    x_col: str, y_col: str, data: pd.DataFrame, out_path_save: str | None = None
):
    """Draw a regression plot of two columns, titled with a permutation test.

    Parameters
    ----------
    x_col : str
        Column of ``data`` used as x.
    y_col : str
        Column of ``data`` used as y.
    data : pd.DataFrame
        Table containing both columns.
    out_path_save : str | None, optional
        Destination of the figure; nothing is saved when None.
    """
    from py_neuromodulation.analysis.stats import permutationTestSpearmansRho

    plt.figure(figsize=(4, 4), dpi=300)

    # NOTE(review): the three trailing positional arguments are passed
    # through verbatim; their meaning is defined by the helper (presumably
    # plot flag, label and number of permutations) - confirm there.
    x_values = data[x_col]
    y_values = data[y_col]
    rho, p = permutationTestSpearmansRho(x_values, y_values, False, "R^2", 5000)

    sb.regplot(x=x_col, y=y_col, data=data)
    plt.title(f"{y_col}~{x_col} p={np.round(p, 2)} rho={np.round(rho, 2)}")

    if out_path_save is None:
        return
    plt.savefig(out_path_save, bbox_inches="tight")
def plot_bar_performance_per_channel(
    ch_names,
    performances: dict,
    PATH_OUT: _PathLike,
    sub: str | None = None,
    save_str: str = "ch_comp_bar_plt.png",
    performance_metric: str = "Balanced Accuracy",
):
    """Save a bar chart of the test performance of each entry of one subject.

    Parameters
    ----------
    ch_names
        Labels placed under the bars; also determines the number of bars.
    performances : dict
        Output of ml_decode: ``performances[sub][key]["performance_test"]``.
    PATH_OUT : _PathLike
        Folder the image is written to.
    sub : str | None, optional
        Subject key; the first key of ``performances`` is used when None.
    save_str : str, optional
        File name of the saved image.
    performance_metric : str, optional
        Label of the y axis.
    """
    if sub is None:
        sub = list(performances)[0]

    # One bar per entry of the subject's dict, in the dict's own order.
    heights = [entry["performance_test"] for entry in performances[sub].values()]
    bar_positions = np.arange(len(ch_names))

    plt.figure(figsize=(4, 3), dpi=300)
    plt.bar(bar_positions, heights)
    plt.xticks(bar_positions, ch_names, rotation=90)
    plt.xlabel("channels")
    plt.ylabel(performance_metric)

    target = PurePath(PATH_OUT, save_str)
    plt.savefig(target, bbox_inches="tight")
    plt.close()
def plot_corr_matrix(
    feature: pd.DataFrame,
    feature_file: _PathLike = "",
    ch_name: str = "",
    feature_names: list[str] = [],
    show_plot=True,
    OUT_PATH: _PathLike = "",
    feature_name_plt="Features_corr_matr",
    save_plot: bool = False,
    save_plot_name: str = "",
    figsize: tuple[float, float] = (7, 7),
    title: str = "",
    cbar_vmin: float = -1,
    cbar_vmax: float = 1.0,
):
    """Plot the correlation matrix of feature columns as a heatmap.

    Parameters
    ----------
    feature : pd.DataFrame
        Feature table; correlations are computed between its columns.
    feature_file : _PathLike, optional
        Run name, used as sub-folder of the automatically built save path.
    ch_name : str, optional
        Channel name. When given, a leading ``"<ch_name>_"`` is removed from
        the tick labels and the channel is named in the default title.
    feature_names : list[str], optional
        Columns to correlate; all columns are used when empty. The default
        list is never mutated.
    show_plot : bool, optional
        When False the figure is closed before returning.
    OUT_PATH : _PathLike, optional
        Output folder used when saving.
    feature_name_plt : str, optional
        Plot-type prefix of the automatically built file name.
    save_plot : bool, optional
        Whether to write the figure to disk.
    save_plot_name : str, optional
        Explicit file name inside ``OUT_PATH``; overrides the automatic name.
    figsize : tuple[float, float], optional
        Figure size.
    title : str, optional
        Figure title; a default title is used when empty.
    cbar_vmin, cbar_vmax : float, optional
        Limits of the colour scale.

    Returns
    -------
    The axes holding the heatmap.
    """
    # len() instead of truthiness so a pandas Index is accepted as well.
    if len(feature_names) > 0:
        corr = feature[feature_names].corr()
    else:
        corr = feature.corr()

    # Take the labels from the matrix itself so their number always matches,
    # and strip the channel prefix only when a channel name was supplied.
    tick_labels = list(corr.columns)
    if ch_name:
        prefix = f"{ch_name}_"
        tick_labels = [
            label[len(prefix) :]
            if isinstance(label, str) and label.startswith(prefix)
            else label
            for label in tick_labels
        ]

    plt.figure(figsize=figsize)
    ax = sb.heatmap(
        corr,
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        vmin=cbar_vmin,
        vmax=cbar_vmax,
        cmap="viridis",
    )

    if title:
        plt.title(title)
    elif ch_name:
        plt.title("Correlation matrix features channel: " + str(ch_name))
    else:
        plt.title("Correlation matrix")

    # Lay out before saving/closing; doing it after plt.close() would spawn
    # a new empty figure.
    plt.tight_layout()

    if save_plot:
        if save_plot_name:
            plt_path = PurePath(OUT_PATH, save_plot_name)
        else:
            plt_path = get_plt_path(
                OUT_PATH=OUT_PATH,
                feature_file=feature_file,
                ch_name=ch_name,
                str_plt_type=feature_name_plt,
                feature_name="_".join(feature_names),
            )
        plt.savefig(plt_path, bbox_inches="tight")
        logger.info(f"Correlation matrix figure saved to {plt_path}")

    if not show_plot:
        plt.close()

    return ax
def plot_feature_series_time(features) -> None:
    """Show the transpose of ``features`` as an image with automatic aspect.

    NOTE(review): presumably ``features`` is laid out as (time, feature) so
    that time runs along the x axis - confirm with callers.
    """
    image = features.T
    plt.imshow(image, aspect="auto")
def get_plt_path(
    OUT_PATH: _PathLike = "",
    feature_file: str = "",
    ch_name: str = "",
    str_plt_type: str = "",
    feature_name: str = "",
) -> _PathLike:
    """Build the file path of a figure: ``OUT_PATH/feature_file/<name>.png``.

    The file name is ``str_plt_type`` followed by ``_ch_<ch_name>`` and
    ``_<feature_name>`` (each only when non-empty) and the ``.png`` suffix.

    Parameters
    ----------
    OUT_PATH : _PathLike, optional
        folder of preprocessed runs, by default ""
    feature_file : str, optional
        run name, used as sub-folder, by default ""
    ch_name : str, optional
        channel name, by default ""
    str_plt_type : str, optional
        type of plot, e.g. mov_avg_feature or corr_matr, by default ""
    feature_name : str, optional
        e.g. bandpower, stft, sharpwave_prominence, by default ""

    Returns
    -------
    PurePath
        The assembled path; nothing is created on disk.
    """
    name_parts = [str_plt_type]
    if ch_name:
        name_parts.append("_ch_" + ch_name)
    if feature_name:
        name_parts.append("_" + feature_name)
    name_parts.append(".png")

    return PurePath(OUT_PATH, feature_file, "".join(name_parts))
def _epoch_time_ticks(n_samples: int, epoch_len: float, n_ticks: int = 10):
    """Return up to ``n_ticks`` evenly spread sample indices and their times in s."""
    if n_samples <= 0:
        return np.array([], dtype=int), np.array([], dtype=float)
    n_ticks = max(1, min(n_ticks, n_samples))
    positions = (np.arange(n_ticks) * n_samples) // n_ticks
    # Sample 0 sits at -epoch_len / 2; the epoch spans epoch_len seconds.
    labels = np.round(-epoch_len / 2 + positions / n_samples * epoch_len, 2)
    return positions, labels


def plot_epochs_avg(
    X_epoch: np.ndarray,
    y_epoch: np.ndarray,
    epoch_len: int,
    sfreq: int,
    feature_names: list[str] = [],
    feature_str_add: str = "",
    cut_ch_name_cols: bool = True,
    ch_name: str = "",
    label_name: str = "",
    normalize_data: bool = True,
    show_plot: bool = True,
    save: bool = False,
    OUT_PATH: _PathLike = "",
    feature_file: str = "",
    str_title: str = "Movement aligned features",
    ytick_labelsize=None,
    figsize_x: float = 8,
    figsize_y: float = 8,
) -> None:
    """Plot epoch-averaged features above the individual and mean targets.

    Parameters
    ----------
    X_epoch : np.ndarray
        Epoched features; squeezed and NaN-averaged over axis 0.
        NOTE(review): assumed to be (epoch, time, feature) - confirm.
    y_epoch : np.ndarray
        Epoched target, one row per epoch.
    epoch_len : int
        Epoch duration in seconds; the x axis runs from ``-epoch_len / 2``.
    sfreq : int
        Feature sampling rate. Kept for interface compatibility; the function
        body does not read it.
    feature_names : list[str], optional
        Labels of the image rows. The default list is never mutated.
    feature_str_add : str, optional
        Suffix of the saved file name.
    cut_ch_name_cols : bool, optional
        When True and ``ch_name`` is given, keep only names containing
        ``ch_name`` and remove a leading ``"<ch_name>_"`` from them.
    ch_name : str, optional
        Channel name, appended to the title and used for the file name.
    label_name : str, optional
        Title of the target subplot.
    normalize_data : bool, optional
        Z-score the averaged features along axis 0.
    show_plot : bool, optional
        When False the figure is closed before returning.
    save : bool, optional
        Whether to write the figure to disk.
    OUT_PATH : _PathLike, optional
        Output folder used when saving.
    feature_file : str, optional
        Run name, used as sub-folder when saving.
    str_title : str, optional
        Title of the feature image.
    ytick_labelsize : float | None, optional
        Font size of the feature labels.
    figsize_x, figsize_y : float, optional
        Figure size.
    """
    from scipy.stats import zscore

    # Drop the channel prefix from the labels for a denser plot. The check
    # uses len() so array-like name containers work too.
    if cut_ch_name_cols and ch_name and len(feature_names) > 0:
        prefix = f"{ch_name}_"
        feature_names = [
            name[len(prefix) :] if name.startswith(prefix) else name
            for name in feature_names
            if ch_name in name
        ]

    epoch_mean = np.nanmean(np.squeeze(X_epoch), axis=0)
    if normalize_data:
        epoch_mean = zscore(epoch_mean, axis=0, nan_policy="omit")
    X_epoch_mean = epoch_mean.T
    if X_epoch_mean.ndim == 1:
        # A single feature: restore the row axis expected by imshow.
        X_epoch_mean = np.expand_dims(X_epoch_mean, axis=0)

    # Positions and labels come from one helper so their counts always agree,
    # including for fewer than 10 samples or sample counts that are not a
    # multiple of 10.
    tick_positions, tick_labels = _epoch_time_ticks(
        X_epoch_mean.shape[1], epoch_len
    )

    plt.figure(figsize=(figsize_x, figsize_y))
    gs = gridspec.GridSpec(2, 1, height_ratios=[2.5, 1])

    plt.subplot(gs[0])
    plt.imshow(X_epoch_mean, aspect="auto")
    plt.yticks(
        np.arange(0, len(feature_names), 1), feature_names, size=ytick_labelsize
    )
    plt.xticks(tick_positions, tick_labels, rotation=90)
    plt.xlabel("Time [s]")
    if ch_name:
        str_title += f" channel: {ch_name}"
    plt.title(str_title)

    plt.subplot(gs[1])
    for trace in y_epoch:
        plt.plot(trace, color="black", alpha=0.4)
    plt.plot(
        y_epoch.mean(axis=0),
        color="black",
        alpha=1,
        linewidth=3.0,
        label="mean target",
    )
    plt.legend()
    plt.ylabel("Target")
    plt.title(label_name)
    plt.xticks(tick_positions, tick_labels, rotation=90)
    plt.xlabel("Time [s]")

    plt.tight_layout()

    if save:
        plt_path = get_plt_path(
            OUT_PATH,
            feature_file,
            ch_name,
            str_plt_type="MOV_aligned_features",
            feature_name=feature_str_add,
        )
        plt.savefig(plt_path, bbox_inches="tight")
        logger.info(f"Feature epoch average figure saved to: {str(plt_path)}")

    if not show_plot:
        plt.close()
def plot_grid_elec_3d(
    cortex_grid: np.ndarray | None = None,
    ecog_strip: np.ndarray | None = None,
    grid_color: np.ndarray | None = None,
    strip_color: np.ndarray | None = None,
):
    """Scatter grid points and strip electrodes on a new 3D axes.

    Parameters
    ----------
    cortex_grid : np.ndarray | None, optional
        Grid coordinates; columns 0, 1 and 2 are used as x, y and z.
    ecog_strip : np.ndarray | None, optional
        Electrode coordinates; columns 0, 1 and 2 are used as x, y and z.
    grid_color : np.ndarray | None, optional
        One colour value per grid point; all ones when None.
    strip_color : np.ndarray | None, optional
        One colour value per electrode; all ones when None.
    """
    ax = plt.axes(projection="3d")

    if cortex_grid is not None:
        if grid_color is None:
            grid_color = np.ones(cortex_grid.shape[0])
        grid_x, grid_y, grid_z = (cortex_grid[:, axis] for axis in range(3))
        ax.scatter3D(
            grid_x,
            grid_y,
            grid_z,
            c=grid_color,
            s=300,
            alpha=0.8,
            cmap="viridis",
        )

    if ecog_strip is not None:
        if strip_color is None:
            strip_color = np.ones(ecog_strip.shape[0])
        strip_x, strip_y, strip_z = (ecog_strip[:, axis] for axis in range(3))
        # On a 3D axes the third positional argument of scatter is the z
        # coordinate; the marker size is the separate keyword ``s``.
        ax.scatter(
            strip_x,
            strip_y,
            strip_z,
            c=strip_color,
            s=500,
            alpha=0.8,
            cmap="gray",
            marker="o",
        )
def plot_all_features(
    df: pd.DataFrame,
    time_limit_low_s: float | None = None,
    time_limit_high_s: float | None = None,
    normalize: bool = True,
    ytick_labelsize: int = 4,
    clim_low: float | None = None,
    clim_high: float | None = None,
    save: bool = False,
    title="all_feature_plt.pdf",
    OUT_PATH: _PathLike = "",
    feature_file: str = "",
):
    """Plot every feature column over time as one image.

    Parameters
    ----------
    df : pd.DataFrame
        Feature table with a ``"time"`` column. The limits are multiplied by
        1000 before comparison and the tick labels divide by 1000, i.e. the
        column is treated as milliseconds.
    time_limit_low_s : float | None, optional
        Keep only rows later than this time in seconds.
    time_limit_high_s : float | None, optional
        Keep only rows earlier than this time in seconds.
    normalize : bool, optional
        Z-score each feature column before plotting.
    ytick_labelsize : int, optional
        Font size of the feature labels.
    clim_low, clim_high : float | None, optional
        Lower / upper colour limit; left automatic when None.
    save : bool, optional
        Whether to write the figure to ``OUT_PATH/feature_file/title``.
    title : str, optional
        File name used when saving.
    OUT_PATH : _PathLike, optional
        Output folder used when saving.
    feature_file : str, optional
        Run name; shown in the title and used as sub-folder when saving.

    Raises
    ------
    ValueError
        If no rows remain after applying the time limits.
    """
    from scipy.stats import zscore

    if time_limit_high_s is not None:
        df = df[df["time"] < time_limit_high_s * 1000]
    if time_limit_low_s is not None:
        df = df[df["time"] > time_limit_low_s * 1000]

    n_rows = df.shape[0]
    if n_rows == 0:
        raise ValueError("No rows left to plot after applying the time limits.")

    cols_plt = [c for c in df.columns if c != "time"]
    data_plt = zscore(df[cols_plt], nan_policy="omit") if normalize else df[cols_plt]

    plt.figure()
    plt.imshow(data_plt.T, aspect="auto")
    plt.xlabel("Time [s]")
    plt.ylabel("Feature Names")
    plt.yticks(np.arange(len(cols_plt)), cols_plt, size=ytick_labelsize)

    # About ten ticks; the step must stay >= 1 or np.arange fails for frames
    # with fewer than ten rows.
    tick_step = max(1, n_rows // 10)
    tick_num = np.arange(0, n_rows, tick_step)
    tick_labels = np.array(np.rint(df["time"].iloc[tick_num] / 1000), dtype=int)
    plt.xticks(tick_num, tick_labels)

    plt.title(f"Feature Plot {feature_file}")

    if clim_low is not None:
        plt.clim(vmin=clim_low)
    if clim_high is not None:
        plt.clim(vmax=clim_high)

    plt.colorbar()
    plt.tight_layout()

    if save:
        plt_path = PurePath(OUT_PATH, feature_file, title)
        plt.savefig(plt_path, bbox_inches="tight")
def read_plot_modules(
    PATH_PLOT: _PathLike = PYNM_DIR / "plots",
):
    """Read the .mat files required for plotting.

    Parameters
    ----------
    PATH_PLOT : _PathLike, optional
        folder containing ``faces.mat``, ``Vertices.mat``, ``grid.mat`` and
        ``STN_surf.mat``, by default the ``plots`` folder of the package

    Returns
    -------
    tuple
        ``(faces, vertices, grid, stn_surf, x_ver, y_ver, x_ecog, y_ecog,
        z_ecog, x_stn, y_stn, z_stn)``. ``x_ver`` / ``y_ver`` hold every
        second STN vertex; the ``*_ecog`` and ``*_stn`` entries hold every
        vertex of the respective surface.
    """
    from py_neuromodulation.utils.io import loadmat

    def _load(file_name):
        return loadmat(PurePath(PATH_PLOT, file_name))

    faces = _load("faces.mat")
    vertices = _load("Vertices.mat")
    grid = _load("grid.mat")["grid"]
    stn_surf = _load("STN_surf.mat")

    cortex_vertices = vertices["Vertices"]
    stn_vertices = stn_surf["vertices"]

    # Subsampled STN outline (every second vertex).
    x_ver = stn_vertices[::2, 0]
    y_ver = stn_vertices[::2, 1]

    # Full-resolution coordinates of both surfaces.
    x_ecog = cortex_vertices[::1, 0]
    y_ecog = cortex_vertices[::1, 1]
    z_ecog = cortex_vertices[::1, 2]
    x_stn = stn_vertices[::1, 0]
    y_stn = stn_vertices[::1, 1]
    z_stn = stn_vertices[::1, 2]

    return (
        faces,
        vertices,
        grid,
        stn_surf,
        x_ver,
        y_ver,
        x_ecog,
        y_ecog,
        z_ecog,
        x_stn,
        y_stn,
        z_stn,
    )
# Class-based plotting interface.
class NM_Plot:
    """Plot projection grids and ECoG strips on the template surfaces."""

    def __init__(
        self,
        ecog_strip: np.ndarray | None = None,
        grid_cortex: np.ndarray | None = None,
        grid_subcortex: np.ndarray | None = None,
        sess_right: bool | None = False,
        proj_matrix_cortex: np.ndarray | None = None,
    ) -> None:
        """Store the coordinates and load the surface files.

        Parameters
        ----------
        ecog_strip : np.ndarray | None, optional
            Electrode coordinates, one row per electrode.
        grid_cortex : np.ndarray | None, optional
            Cortical grid coordinates, one row per grid point.
        grid_subcortex : np.ndarray | None, optional
            Stored only; not read by the methods of this class.
        sess_right : bool | None, optional
            Stored only; ``plot_cortex`` uses its own argument.
        proj_matrix_cortex : np.ndarray | None, optional
            Stored only; not read by the methods of this class.
        """
        self.grid_cortex = grid_cortex
        self.grid_subcortex = grid_subcortex
        self.ecog_strip = ecog_strip
        self.sess_right = sess_right
        self.proj_matrix_cortex = proj_matrix_cortex

        (
            self.faces,
            self.vertices,
            self.grid,
            self.stn_surf,
            self.x_ver,
            self.y_ver,
            self.x_ecog,
            self.y_ecog,
            self.z_ecog,
            self.x_stn,
            self.y_stn,
            self.z_stn,
        ) = read_plot_modules()

    def plot_grid_elec_3d(self) -> None:
        """Scatter the stored cortical grid and ECoG strip in 3D."""
        # Keep missing inputs as None: np.array(None) is a 0-d array that
        # would pass the "is not None" checks of the plotting function.
        grid = None if self.grid_cortex is None else np.array(self.grid_cortex)
        strip = None if self.ecog_strip is None else np.array(self.ecog_strip)
        plot_grid_elec_3d(grid, strip)

    def plot_cortex(
        self,
        grid_cortex: np.ndarray | pd.DataFrame | None = None,
        grid_color: np.ndarray | None = None,
        ecog_strip: np.ndarray | None = None,
        strip_color: np.ndarray | None = None,
        sess_right: bool | None = None,
        save: bool = False,
        OUT_PATH: _PathLike = "",
        feature_file: str = "",
        feature_str_add: str = "",
        show_plot: bool = True,
        title: str = "Cortical grid",
        set_clim: bool = True,
        lower_clim: float = 0.5,
        upper_clim: float = 0.7,
        cbar_label: str = "Balanced Accuracy",
    ):
        """Plot MNI brain including selected MNI cortical projection grid + used strip ECoG electrodes

        Colorcoded by grid_color

        Parameters
        ----------
        grid_cortex : np.ndarray | pd.DataFrame | None, optional
            Grid coordinates, one row per point with x and y in columns 0 and
            1; the stored grid is used when None.
        grid_color : np.ndarray | None, optional
            One value per grid point; all ones when None.
        ecog_strip : np.ndarray | None, optional
            Electrode coordinates; the stored strip is used when None.
        strip_color : np.ndarray | None, optional
            One value per electrode; all ones when None.
        sess_right : bool | None, optional
            When truthy, the x coordinate of the grid is mirrored on a copy.
            The value stored on the instance is not consulted.
        save : bool, optional
            Whether to write the figure to disk.
        OUT_PATH : _PathLike, optional
            Output folder used when saving.
        feature_file : str, optional
            Run name, used as sub-folder when saving.
        feature_str_add : str, optional
            Suffix of the saved file name.
        show_plot : bool, optional
            When False the figure is closed before returning.
        title : str, optional
            Figure title.
        set_clim : bool, optional
            Apply ``lower_clim`` / ``upper_clim`` to the scatter colours.
        lower_clim, upper_clim : float, optional
            Colour limits.
        cbar_label : str, optional
            Label of the colour bar.
        """
        if grid_cortex is None:
            grid_cortex = self.grid_cortex
        if ecog_strip is None:
            ecog_strip = self.ecog_strip

        # Positional slicing below needs plain arrays, also for DataFrames
        # handed in as arguments.
        if grid_cortex is not None:
            grid_cortex = np.asarray(grid_cortex)
        if ecog_strip is not None:
            ecog_strip = np.asarray(ecog_strip)

        if sess_right and grid_cortex is not None:
            # Mirror x (column 0, the column plotted as x below) on a copy so
            # neither the caller's array nor the stored grid is modified.
            grid_cortex = grid_cortex.copy()
            grid_cortex[:, 0] = -grid_cortex[:, 0]

        fig, axes = plt.subplots(1, 1, facecolor=(1, 1, 1), figsize=(14, 9))
        axes.scatter(self.x_ecog, self.y_ecog, c="gray", s=0.01)
        axes.set_aspect("equal", anchor="C")

        # The colour bar follows the scatter drawn last, as before.
        last_scatter = None

        if grid_cortex is not None:
            if grid_color is None:
                grid_color = np.ones(grid_cortex.shape[0])
            last_scatter = axes.scatter(
                grid_cortex[:, 0],
                grid_cortex[:, 1],
                c=grid_color,
                s=150,
                alpha=0.8,
                cmap="viridis",
                label="grid points",
            )
            if set_clim:
                last_scatter.set_clim(lower_clim, upper_clim)

        if ecog_strip is not None:
            if strip_color is None:
                strip_color = np.ones(ecog_strip.shape[0])
            last_scatter = axes.scatter(
                ecog_strip[:, 0],
                ecog_strip[:, 1],
                c=strip_color,
                s=400,
                alpha=0.8,
                cmap="viridis",
                marker="x",
                label="ecog electrode",
            )
            if set_clim:
                last_scatter.set_clim(lower_clim, upper_clim)

        plt.axis("off")
        if last_scatter is not None:
            plt.legend()
        plt.title(title)

        # Without grid and strip there is nothing to attach a colour bar to.
        if last_scatter is not None:
            cbar = fig.colorbar(last_scatter, ax=axes)
            cbar.set_label(cbar_label)

        if save:
            plt_path = get_plt_path(
                OUT_PATH,
                feature_file,
                str_plt_type="PLOT_CORTEX",
                feature_name=feature_str_add,
            )
            plt.savefig(plt_path, bbox_inches="tight")
            logger.info(f"Cortex plot figure saved to: {str(plt_path)}")

        if not show_plot:
            plt.close()