# Source code for analysis.RMAP

import numpy as np
from pathlib import PurePath, Path


# from numba import jit
import scipy.io as sio
import pandas as pd
import nibabel as nib
from matplotlib import pyplot as plt

from py_neuromodulation.analysis import reg_plot
from py_neuromodulation.utils.types import _PathLike
from py_neuromodulation import PYNM_DIR

# Grid-point indices that have no streamlines in the structural (dMRI)
# connectome; they are removed from the grid when func_connectivity=False.
# Cortical-hull grid (1025 points):
LIST_STRUC_UNCONNECTED_GRIDPOINTS_HULL = [256, 385, 417, 447, 819, 914]

# Whole-brain grid (1236 points):
LIST_STRUC_UNCONNECTED_GRIDPOINTS_WHOLEBRAIN = [
    1, 8, 16, 33, 34, 35, 36, 37, 51, 75,
    77, 78, 99, 109, 115, 136, 155, 170, 210, 215,
    243, 352, 359, 361, 415, 416, 422, 529, 567, 569,
    622, 623, 625, 627, 632, 633, 634, 635, 639, 640,
    641, 643, 644, 650, 661, 663, 667, 683, 684, 685,
    704, 708, 722, 839, 840, 905, 993, 1011,
]


[docs] class ConnectivityChannelSelector: def __init__( self, whole_brain_connectome: bool = True, func_connectivity: bool = True, ) -> None: """ConnectivityChannelSelector Parameters ---------- whole_brain_connectome : bool, optional if True a 1236 whole-brain point grid is chosen, if False, a 1025 point grid of the cortical hull is loaded, by default True func_connectivity : bool, optional if true, functional connectivity fMRI is loaded, if false structural dMRIby, default True """ self.connectome_name = self._get_connectome_name( whole_brain_connectome, func_connectivity ) self.whole_brain_connectome = whole_brain_connectome self.func_connectivity = func_connectivity self.PATH_CONN_DECODING = PYNM_DIR / "ConnectivityDecoding" if whole_brain_connectome: self.PATH_GRID = PurePath( self.PATH_CONN_DECODING, "mni_coords_whole_brain.mat", ) self.grid = sio.loadmat(self.PATH_GRID)["downsample_ctx"] if not func_connectivity: # reduce the grid to only valid points that are not in LIST_STRUC_UNCONNECTED_GRIDPOINTS_WHOLEBRAIN self.grid = np.delete( self.grid, LIST_STRUC_UNCONNECTED_GRIDPOINTS_WHOLEBRAIN, axis=0, ) else: self.PATH_GRID = PurePath( self.PATH_CONN_DECODING, "mni_coords_cortical_surface.mat", ) self.grid = sio.loadmat(self.PATH_GRID)["downsample_ctx"] if not func_connectivity: # reduce the grid to only valid points that are not in LIST_STRUC_UNCONNECTED_GRIDPOINTS_HULL self.grid = np.delete( self.grid, LIST_STRUC_UNCONNECTED_GRIDPOINTS_HULL, axis=0 ) if func_connectivity: self.RMAP_arr = nib.load( PurePath(self.PATH_CONN_DECODING, "RMAP_func_all.nii") ).get_fdata() else: self.RMAP_arr = nib.load( PurePath(self.PATH_CONN_DECODING, "RMAP_struc.nii") ).get_fdata() def _get_connectome_name(self, whole_brain_connectome: str, func_connectivity: str): connectome_name = "connectome_" if whole_brain_connectome: connectome_name += "whole_brain_" else: connectome_name += "hull_" if func_connectivity: connectome_name += "func" else: connectome_name += "struc" return 
connectome_name
[docs] def get_available_connectomes(self) -> list: """Return list of saved connectomes in the package folder/ConnectivityDecoding/connectome_folder/ folder. Returns ------- list_connectomes: list """ return list(Path(self.PATH_CONN_DECODING, "connectome_folder").iterdir())
[docs] def plot_grid(self) -> None: """Plot the loaded template grid that passed coordinates are matched to.""" fig = plt.figure() ax = fig.add_subplot(111, projection="3d") ax.scatter(self.grid[:, 0], self.grid[:, 1], self.grid[:, 2], s=50, alpha=0.2) plt.show()
[docs] def get_closest_node(self, coord: list | np.ndarray) -> tuple[list, list]: """Given a list or np.array of coordinates, return the closest nodes in the grid and their indices. Parameters ---------- coord : np.ndarray MNI coordinates with shape (num_channels, 3) Returns ------- Tuple[list, list] Grid coordinates, grid indices """ idx_ = [] for c in coord: dist = np.linalg.norm(self.grid - c, axis=1) idx_.append(np.argmin(dist)) return [self.grid[idx] for idx in idx_], idx_
[docs] def get_rmap_correlations( self, fps: list | np.ndarray, RMAP_use: np.ndarray | None = None ) -> list: """Calculate correlations of passed fingerprints with the RMAP Parameters ---------- fps : Union[list, np.array] List of fingerprints RMAP_use : np.ndarray, optional Passed RMAP, by default None Returns ------- List correlation values """ RMAP_ = self.RMAP_arr if RMAP_use is None else RMAP_use RMAP_ = RMAP_.flatten() corrs = [] for fp in fps: corrs.append(np.corrcoef(RMAP_, fp.flatten())[0][1]) return corrs
[docs] def load_connectome( self, whole_brain_connectome: bool | None = None, func_connectivity: bool | None = None, ) -> None: """Load connectome, if not available download connectome from Zenodo. Parameters ---------- whole_brain_connectome : bool, optional if true whole brain connectome if false cortical hull grid connectome, by default None func_connectivity : bool, optional if true fMRI if false dMRI, by default None """ if whole_brain_connectome is not None: self.whole_brain_connectome = whole_brain_connectome if func_connectivity is not None: self.func_connectivity = func_connectivity self.connectome_name = self._get_connectome_name( self.whole_brain_connectome, self.func_connectivity ) PATH_CONNECTOME = Path( self.PATH_CONN_DECODING, "connectome_folder", self.connectome_name + ".mat", ) if not PATH_CONNECTOME.exists(): user_input = input( "Do you want to download the connectome? (yes/no): " ).lower() if user_input == "yes": self._download_connectome() elif user_input == "no": print("Connectome missing, has to be downloaded") self.connectome = sio.loadmat(PATH_CONNECTOME)
def get_grid_fingerprints(self, grid_idx: list | np.ndarray) -> list: return [self.connectome[str(grid_idx)] for grid_idx in grid_idx] def download_connectome( self, ): from urllib.request import urlretrieve # download the connectome from the Zenodo API print("Downloading the connectome...") record_id = "10804702" file_name = self.connectome_name filepath = Path(self.PATH_CONN_DECODING, "connectome_folder") filepath.mkdir(parents=True, exist_ok=True) urlretrieve( f"https://zenodo.org/api/records/{record_id}/files/{file_name}/content", filepath / f"{self.connectome_name}.mat", )
[docs] class RMAPCross_Val_ChannelSelector: def __init__(self) -> None: pass
[docs] def load_fingerprint(self, path_nii) -> None: """Return Nifti fingerprint""" epi_img = nib.load(path_nii) self.affine = epi_img.affine fp = epi_img.get_fdata() return fp
def load_all_fingerprints(self, path_dir: str, cond_str: str = "_AvgR_Fz.nii"): if cond_str is not None: l_fps = list(filter(lambda k: cond_str in str(k), Path(path_dir).iterdir())) else: l_fps = list(Path(path_dir).iterdir()) return l_fps, [self.load_fingerprint(PurePath(path_dir, f)) for f in l_fps] def get_fingerprints_from_path_with_cond( self, path_dir: _PathLike, str_to_omit: str = "", str_to_keep: str = "", keep: bool = True, ) -> tuple[list, list]: l_fps = [] if keep and str_to_keep: l_fps = list( filter( lambda k: "_AvgR_Fz.nii" in str(k) and str_to_keep in str(k), Path(path_dir).iterdir(), ) ) elif not keep and str_to_omit: l_fps = list( filter( lambda k: "_AvgR_Fz.nii" in str(k) and str_to_omit not in str(k), Path(path_dir).iterdir(), ) ) return l_fps, [self.load_fingerprint(PurePath(path_dir, f)) for f in l_fps] @staticmethod def save_Nii( fp: np.ndarray, affine: np.ndarray, name: str = "img.nii", reshape: bool = True, ): if reshape: fp = np.reshape(fp, (91, 109, 91), order="C") img = nib.nifti1.Nifti1Image(fp, affine=affine) nib.save(img, name) def get_RMAP(self, X: np.ndarray, y: np.ndarray): # faster than calculate_RMap_numba # https://stackoverflow.com/questions/71252740/correlating-an-array-row-wise-with-a-vector/71253141#71253141 r = ( len(y) * np.sum(X * y[None, :], axis=-1) - (np.sum(X, axis=-1) * np.sum(y)) ) / ( np.sqrt( (len(y) * np.sum(X**2, axis=-1) - np.sum(X, axis=-1) ** 2) * (len(y) * np.sum(y**2) - np.sum(y) ** 2) ) ) return r @staticmethod # @jit(nopython=True) def calculate_RMap_numba(fp, performances): # The RMap also needs performances; for every fingerprint / channel # Save the corresponding performance # for every voxel; correlate it with performances arr = fp[0].flatten() NUM_VOXELS = arr.shape[0] LEN_FPS = len(fp) fp_arr = np.empty((NUM_VOXELS, LEN_FPS)) for fp_idx, fp_ in enumerate(fp): fp_arr[:, fp_idx] = fp_.flatten() RMAP = np.zeros(NUM_VOXELS) for voxel in range(NUM_VOXELS): corr_val = np.corrcoef(fp_arr[voxel, :], 
performances)[0][1] RMAP[voxel] = corr_val return RMAP @staticmethod # @jit(nopython=True) def get_corr_numba(fp, fp_test): val = np.corrcoef(fp_test, fp)[0][1] return val def leave_one_ch_out_cv(self, l_fps_names: list, l_fps_dat: list, l_per: list): # l_fps_dat is not flattened per_left_out = [] per_predict = [] for idx_left_out, f_left_out in enumerate(l_fps_names): # print(idx_left_out) l_cv = l_fps_dat.copy() per_cv = l_per.copy() l_cv.pop(idx_left_out) per_cv.pop(idx_left_out) conn_arr = [] for f in l_cv: conn_arr.append(f.flatten()) conn_arr = np.array(conn_arr) rmap_cv = np.nan_to_num(self.get_RMAP(conn_arr.T, np.array(per_cv))) per_predict.append( np.nan_to_num( self.get_corr_numba(rmap_cv, l_fps_dat[idx_left_out].flatten()) ) ) per_left_out.append(l_per[idx_left_out]) return per_left_out, per_predict def leave_one_sub_out_cv( self, l_fps_names: list, l_fps_dat: list, l_per: list, sub_list: list ): # l_fps_dat assume non flatted arrays # each fp including the sub_list string will be iteratively removed for test set per_predict = [] per_left_out = [] for subject_test in sub_list: # print(subject_test) idx_test = [idx for idx, f in enumerate(l_fps_names) if subject_test in f] idx_train = [ idx for idx, f in enumerate(l_fps_names) if subject_test not in f ] l_cv = list(np.array(l_fps_dat)[idx_train]) per_cv = list(np.array(l_per)[idx_train]) conn_arr = [] for f in l_cv: conn_arr.append(f.flatten()) conn_arr = np.array(conn_arr) rmap_cv = np.nan_to_num(self.get_RMAP(conn_arr.T, np.array(per_cv))) for idx in idx_test: per_predict.append( np.nan_to_num( self.get_corr_numba(rmap_cv, l_fps_dat[idx].flatten()) ) ) per_left_out.append(l_per[idx]) return per_left_out, per_predict def get_highest_corr_sub_ch( self, cohort_test: str, sub_test: str, ch_test: str, cohorts_train: dict, path_dir: str = r"C:\Users\ICN_admin\OneDrive - Charité - Universitätsmedizin Berlin\Connectomics\DecodingToolbox_BerlinPittsburgh_Beijing\functional_connectivity", ): fp_test = 
self.get_fingerprints_from_path_with_cond( path_dir=path_dir, str_to_keep=f"{cohort_test}_{sub_test}_ROI_{ch_test}", keep=True, )[1][ 0 ].flatten() # index 1 for getting the array, 0 for the list fp that was found fp_pairs = [] for cohort in cohorts_train.keys(): for sub in cohorts_train[cohort]: fps_name, fps = self.get_fingerprints_from_path_with_cond( path_dir=path_dir, str_to_keep=f"{cohort}_{sub}_ROI", keep=True, ) for fp, fp_name in zip(fps, fps_name): ch = fp_name[fp_name.find("ROI") + 4 : fp_name.find("func") - 1] corr_val = self.get_corr_numba(fp_test, fp) fp_pairs.append([cohort, sub, ch, corr_val]) idx_max = np.argmax(np.array(fp_pairs)[:, 3]) return fp_pairs[idx_max][0:3] def plot_performance_prediction_correlation( per_left_out, per_predict, out_path_save: str | None = None ): df_plt_corr = pd.DataFrame() df_plt_corr["test_performance"] = per_left_out df_plt_corr["struct_conn_predict"] = ( per_predict # change "struct" with "funct" for functional connectivity ) reg_plot( x_col="test_performance", y_col="struct_conn_predict", data=df_plt_corr, out_path_save=out_path_save, )