Skip to content

Commit

Permalink
refactoring for namespace/objects; small_ophys_inference shows a movi…
Browse files Browse the repository at this point in the history
…e instead of a single frame
  • Loading branch information
stevevanhooser committed May 4, 2024
1 parent ea0b27a commit eb3e202
Show file tree
Hide file tree
Showing 25 changed files with 731 additions and 1 deletion.
Binary file added +deepinterp/+Net/.ImageTimeSeries.m.swp
Binary file not shown.
137 changes: 137 additions & 0 deletions +deepinterp/+Net/ImageTimeSeries.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
classdef ImageTimeSeries < deepinterp.Net

properties (GetAccess=public, SetAccess=protected)
N % number of rows of inputs to the network
M % number of columns of inputs to the network
Npre % Number of input frames needed before the predicted frame
Npost % Number of input frames needed after the predicted frame
end


methods

function obj = ImageTimeSeries(command, options)
% OBJ = NET(COMMAND,...)
%
% Create a new deepinterp.Net.ImageTimeSeries object.
%
% deepinterp.Net.ImageTimeSeries objects remove independent
% frame-by-frame noise in image sequences that are usually time
% series. The network takes Npre video frames _before_ the frame
% to be denoised, and Npost video frames _after_ the frame to be
% denoised, to make a computational prediction of the denoised
% frame.
%
% COMMAND can be:
%
% 'New' : create a new empty network
% 'KerasFile': open and import a Keras file
%
% The function also accepts additional arguments as name/value pairs:
% --------------------------------------------------------------------
% | Parameter (default) | Description |
% |-------------------------|----------------------------------------|
% | file ('') | Filename to open if reading a file |
% | N (512) | Number of rows in images if new network|
% | M (512) | Number of columns in images for new net|
% | Npre (30) | Number of frames before prediction |
% | Npost (30) | Number of frames after prediction |
% |-------------------------|----------------------------------------|
%
% Example:
% sampleNet = deepinterp.sampleFile(...
% '2019_09_11_23_32_unet_single_1024_mean_absolute_error_Ai93-0450.h5');
% n = deepinterp.Net.ImageTimeSeries('KerasFile','file',sampleNet,...
% 'Npre',30,'Npost',30);
%
arguments
command (1,:) char {mustBeTextScalar}
options.file (1,:) char {mustBeFile}
options.N (1,1) double {mustBeInteger} = 512;
options.M (1,1) double {mustBeInteger} = 512;
options.Npre (1,1) double {mustBeInteger} = 30;
options.Npost (1,1) double {mustBeInteger} = 30;
end

superOptions = rmfield(options,{'N','M','Npre','Npost'});
superInputs = cat(2,{command},...
deepinterp.struct2namevaluepair(superOptions));
obj = [email protected](superInputs{:});
obj.Npre = options.Npre;
obj.Npost = options.Npost;

if strcmp(command,'New'),
disp(['In future we will build a net network here.']);
end;

obj = obj.SetInputSize();

end; % ImageTimeSeries - constructor

function obj = SetInputSize(obj)
% SETINPUTSIZE - set the input size properties
%
% OBJ = SETINPUTSIZE(OBJ)
%
% Sets the input size properties N and M, and verifies that
% the number of image sequences required in each input add up to
% obj.Npre and obj.Npost.
%
% The values can be examined by OBJ.N, OBJ.M.
%
if isempty(obj.network),
obj.N = NaN;
obj.M = NaN;
else,
inSz = obj.network.Layers(1).InputSize;
obj.N = inSz(1);
obj.M = inSz(2);
assert(inSz(3)==obj.Npre+obj.Npost,...
['Npre + Npost must add up to the total ' ...
'number of image inputs to ' ...
'deepinterp.net.ImageTimeSeries.']);
end;
end; % SetInputSize

function output = interp(obj, input, options)
% INTERP - apply DeepInterpolation to inputs
%
% OUTPUT = INTERP(DEEPINTERP_IMAGETIMESERIES_OBJ, INPUT)
%
% Apply the DeepInterpolation network to the INPUT data.
% INPUT should be an NxMxT matrix, where N and M match
% the size of the DEEPINTERP_IMAGETIMESERIES_OBJ.N and .M parameters,
% and T must be greater than the Npre + Npost property values of
% DEEPINTERP_IMAGETIMESERIES_OBJ. Interpolation will only be performed
% for frames greater than Npre and less than numel(T) - Npost. The first
% Npre frames and last Npost frames will be equal to the input.
%
arguments
obj (1,1)
input
options.progbar (1,1) logical = true;
end

output = input;
offsets = [-obj.Npre:-1 1:obj.Npost];
if options.progbar,
deepinterp.progressbar();
end;
totalWork = size(input,3)-(obj.Npost+obj.Npre);
for t = obj.Npre+1 : size(input,3) - obj.Npost,
if options.progbar,
deepinterp.progressbar((t-(obj.Npre+1))/totalWork);
end;
output(:,:,t) = predict(obj.network,input(:,:,offsets+t));
end;
if options.progbar,
progressbar(0.9999999999);
end;
end; % interp()

% INTERP - apply DeepInterpolation


end; % methods

end % classdef
76 changes: 76 additions & 0 deletions +deepinterp/Net.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
classdef Net

properties (SetAccess=protected)
network % Deep learning network
end


methods
function obj = Net(command,options)
% NET - create a new deepinterp.Net object
%
% OBJ = NET(COMMAND,...)
%
% Create a new deepinterp.Net object according to instructions.
% COMMAND can be:
%
% 'KerasFile': open and import a Keras file
% 'New' : create a new empty network
%
% The function also accepts additional arguments as name/value pairs:
% --------------------------------------------------------------------
% | Parameter (default) | Description |
% |-------------------------|----------------------------------------|
% | file ('') | Filename to open if reading a file |
% |-------------------------|----------------------------------------|
%
% Example:
% n = deepinterp.Net('KerasFile','file','myKerasFile.H5');
%
arguments
command (1,:) char {mustBeTextScalar}
options.file (1,:) char {mustBeFile}
end

switch (command),
case 'KerasFile',
obj.network=deepinterp.importKerasMAE(options.file);
case 'New',
% nothing to do
otherwise,
error(['Unknown command: ' command]);
end;

end; % Net constructor

function net_obj = train(net_obj, inputs, outputs)
% TRAIN - train a DeepInterpolation network
%
% NET_OBJ = TRAIN(NET_OBJ, INPUTS, OUTPUTS)
%
% Train a deepinterp.Net object with example
% INPUTS and OUTPUTS.
%
% Subclasses should describe the form of these inputs
% and outputs.
%

end; % train()

function output = interp(net_obj, inputs)
% INTERP - apply DeepInterpolation
%
% OUTPUT = INTERP(NET_OBJ, INPUT)
%
% Apply the DeepInterpolation network to the
% INPUT data.
%
% Subclasses should describe the form of the input
% and output.
%

end; % interp()
end;

end % classdef

19 changes: 19 additions & 0 deletions +deepinterp/importKerasMAE.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
function net = importKerasMAE(kerasFile)
% IMPORTKERAS_MAE - import a Keras network, converting placeholder to mae
%
% NET = IMPORTKERASMAE(KERASFILE)
%
% Imports a Keras Deep Learning Network from the file KERASFILE.
% PLACEHOLDER layers are replaced with mean-average-error
% layers.
%

importednet = importKerasLayers(kerasFile,'ImportWeights',true);

placeholders = findPlaceholderLayers(importednet);

regressionnet = replaceLayer(importednet, placeholders.Name , ...
deepinterp.maeRegressionLayer);

net = assembleNetwork(regressionnet);

34 changes: 34 additions & 0 deletions +deepinterp/maeRegressionLayer.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
classdef maeRegressionLayer < nnet.layer.RegressionLayer
% Example custom regression layer with mean-absolute-error loss.

methods
function layer = maeRegressionLayer(name)
% layer = maeRegressionLayer(name) creates a
% mean-absolute-error regression layer and specifies the layer
% name.

arguments
name(1, 1) string = "MAE Regression"
end

% Set layer name.
layer.Name = name;

% Set layer description.
layer.Description = 'Mean absolute error';
end

function loss = forwardLoss(layer, Y, T)
% loss = forwardLoss(layer, Y, T) returns the MAE loss between
% the predictions Y and the training targets T.

% Calculate MAE.
R = size(Y,3);
meanAbsoluteError = sum(abs(Y-T),3)/R;

% Take mean over mini-batch.
N = size(Y,4);
loss = sum(meanAbsoluteError)/N;
end
end
end
Loading

0 comments on commit eb3e202

Please sign in to comment.