Skip to content

Commit

Permalink
Merge pull request #42 from eminSerin/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
eminSerin authored Sep 5, 2024
2 parents b804e13 + e9492a9 commit 9bea0ef
Show file tree
Hide file tree
Showing 16 changed files with 454 additions and 31 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Operating system files
.DS_Store
*.DS_Store
Thumbs.db
.AppleDouble
.LSOverride
Expand All @@ -14,3 +14,5 @@ src/ext/BrainNetViewer_20181219/tmp.node
*.asv
src/.DS_Store
test/.DS_Store
src/.DS_Store

194 changes: 194 additions & 0 deletions TangentSpace.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
classdef TangentSpace < handle
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TangentSpace Transforms positive definite covariance matrices
% into tangent space. Transforming covariance matrices
% (or positive definite correlation matrices) into tangent space
% has been shown to provide significantly higher
% prediction performance (Dadi et al., 2019; Pervaiz et al., 2020).
%
%
%
% Example:
% vectorizer = TangentSpace;
% edgeMatrix = vectorizer.transform(data);
%
% References:
% Dadi, Kamalaker, et al. "Benchmarking functional connectome-based
% predictive models for resting-state fMRI."
% NeuroImage 192 (2019): 115-134.
% Pervaiz, Usama, et al. "Optimising network modelling
% methods for fMRI." Neuroimage 211 (2020): 116604.
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

properties (Access = private)
dims % dimensions of the data
whitening % whitening matrix
ref_matrix % reference matrix
end

properties (Access = public)
ref
end

methods (Access = private)
function [] = check_positive_definite(~, X)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Checks whether input matrices are positive definite.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
x_dim = size(X);
n_dims = numel(x_dim);
assert(ismember(n_dims, [2,3], 'Input matrix must be 2D or 3D!'));
if numel(x_dim) == 3
for s = 1: x_dim(3)
assert(is_positive_definite(X(:,:,s)),...
'Input matrix is not positive definite!')
end
elseif numel(x_dim) == 2
assert(is_positive_definite(X),...
'Input matrix is not positive definite!')
end
end

function [] = map_shrinkage(obj, X)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Applies shrinkage each covariance matrix in X dataset.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
switch obj.ref
case 'euclidian'
% Takes euclidian mean of input matrix.
obj.ref_matrix = squeeze(mean(X, 3));
case 'log_euclidian'
% Computes log euclidian mean of input matrix.
obj.ref_matrix = exp(mean(log(X), 3));
end
end

function symMat = form_symmetric_matrix(~, X, func)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Constructs a symmetric matrix from a given X matrix
% using eigenvalues and eigenvectors computed from the input
% matrix. While constructing, it also applies a custom
% function to eigenvalues to transform the matrix space.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if nargin < 2
func = @(x) 1./x;
end

% Eigendecomposition
[EV,DV] = eig(X);
DV = func(DV); % transform eigenvalues using given function.
DV(isinf(DV)) = 0;

% form matrix again.
symMat = EV*DV*EV';
end
end

methods
function obj = TangentSpace(varargin)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Initializes class.
%
% Args:
% ref: Method to compute mean reference matrix
% (default = 'euclidian').
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Default parameters.
defaultVals.ref = 'euclidian';
refOptions = {'euclidian', 'log_euclidian'};

% Input Parser
validationRef = @(x) any(validatestring(x,refOptions));
p = inputParser(); p.PartialMatching = 0; % deactivate partial matching.
addParameter(p,'ref',defaultVals.ref,validationRef);

% Parse inputs and store into the object.
parse(p,varargin{:});
obj.ref = p.Results.ref;
obj.dims = [];
end

function obj = fit(obj,X)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Fits the tangent space transformer to the given X matrix.
%
% Args:
% X: 3D (nodes x nodes x subject) input matrix.
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
obj.check_positive_definite(X);
obj.ref_matrix = map_shrinkage(X); % compute reference matrix.
obj.whitening = form_symmetric_matrix(obj, X);
end

function trans_mat = transform(obj,X)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Projects the input covariance matrices into tangent space
% using computed reference mean matrix.
%
% Args:
% X: 3D (nodes x nodes x subject) or 2D (nodes x nodes)
% input matrix.
%
% Output:
% trans_mat: Transformed matrix.
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
obj.check_positive_definite(X);
x_dim = size(X);
n_dims = numel(x_dim);
assert(ismember(n_dims, [2,3], 'Input matrix must be 2D or 3D!'));
if n_dims == 3
nSub = x_dim(3);
trans_mat = zeros(x_dim);
elseif n_dims == 2
nSub = 1;
trans_mat = zeros([1; x_dim(:)]');
end

for s = 1: nSub
trans_mat(s,:,:) = obj.form_symmetric_matrix(...
obj.whitening * X(s,:,:) * obj.whitening, log);
end
trans_mat = squeeze(trans_mat);
end

function untrans_mat = inverse_transform(obj, X)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Inverse transform from tangent space to covariance matrix
% (Riemannian space).
%
% Args:
% X: Tangent transformed input matrix.
%
% Output:
% untrans_mat: Inverse transformed, covariance matrices.
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
sqrt_whitening = obj.form_symmetric_matrix(obj.refMat, sqrt);
x_dim = size(X);
n_dims = numel(x_dim);
assert(ismember(n_dims, [2,3], 'Input matrix must be 2D or 3D!'));
if n_dims == 3
nSub = x_dim(3);
untrans_mat = zeros(x_dim);
elseif n_dims == 2
nSub = 1;
untrans_mat = zeros([1; x_dim(:)]');
end

for s = 1: nSub
untrans_mat(:,:,1) =...
sqrt_whitening * obj.form_symmetric_matrix(X(:,:,s), exp) *...
sqrt_whitening;
end
untrans_mat = squeeze(untrans_mat);
end

end
end



Binary file modified docs/MANUAL.pdf
Binary file not shown.
4 changes: 2 additions & 2 deletions src/helper/generate_randomStream.m
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
% generate_randomStream generates random stream for loop.

if randSeed ~= -1 % -1 refers to random shuffle.
rng(randSeed);
set_seed(randSeed);
else
rng('shuffle');
set_seed('shuffle');
end
rndSeeds = randi(1e+9, iter, 1);
end
Expand Down
2 changes: 1 addition & 1 deletion src/helper/get_version.m
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
function version = get_version()
version = '1.0.0-beta.12';
version = '1.0.0-beta.13';
end
18 changes: 18 additions & 0 deletions src/helper/set_seed.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
function set_seed(seed)
% Get MATLAB release information
releaseInfo = matlabRelease();
releaseName = char(releaseInfo.Release);

% Extract the year and letter from the release name
releaseYear = str2double(releaseName(2:5));
releaseLetter = releaseName(6);

% Check if the release is higher than R2023b
if releaseYear > 2023 || (releaseYear == 2023 && releaseLetter >= 'b')
% MATLAB release is higher than R2023b
rng(seed, 'twister');
else
% MATLAB release is R2023b or lower
rng(seed);
end
end
2 changes: 1 addition & 1 deletion src/model_selection/searching/private/get_searchInputs.m
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@

% Set random state.
if p.Results.randomState
rng(p.Results.randomState);
set_seed(p.Results.randomState);
end

end
Expand Down
12 changes: 6 additions & 6 deletions src/run_NBSPredict.m
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
if cNBSPredict.parameter.numCores > 1
% Run parallelly.
parfor r = 1: repCViter
rng(rndSeeds(r));
set_seed(rndSeeds(r));
[repCVscore(r),edgeWeight(r,:,:),...
truePredLabels(r,:,:),stability(r),...
bestParams(r,:), corrXy{r, :}] = outerFold(cNBSPredict);
Expand All @@ -84,7 +84,7 @@
else
% Run sequentially.
for r = 1: repCViter
rng(rndSeeds(r));
set_seed(rndSeeds(r));
[repCVscore(r),edgeWeight(r,:,:),...
truePredLabels(r,:,:),stability(r),...
bestParams(r,:), corrXy{r, :}] = outerFold(cNBSPredict);
Expand Down Expand Up @@ -371,9 +371,9 @@

% Random Seed
if NBSPredict.parameter.randSeed ~= -1 % -1 refers to random shuffle.
rng(NBSPredict.parameter.randSeed);
set_seed(NBSPredict.parameter.randSeed);
else
rng('shuffle');
set_seed('shuffle');
end
rndSeeds = generate_randomStream(randi(1e+9), permIter);

Expand All @@ -391,7 +391,7 @@
fprintf(permMsg)
end
parfor p = 1: permIter
rng(rndSeeds(p));
set_seed(rndSeeds(p));
permNBSPredict = NBSPredict;
permNBSPredict.data.y = permNBSPredict.data.y(randperm(nSub), :);
[permCVscore(p+1),~, ~, ~] = outerFold(permNBSPredict);
Expand All @@ -403,7 +403,7 @@
permProg = CmdProgress(permMsg, permIter);
end
for p = 1: permIter
rng(rndSeeds(p));
set_seed(rndSeeds(p));
permNBSPredict = NBSPredict;
permNBSPredict.data.y = permNBSPredict.data.y(randperm(nSub), :);
[permCVscore(p+1),~, ~, ~] = outerFold(permNBSPredict);
Expand Down
Loading

0 comments on commit 9bea0ef

Please sign in to comment.