-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactoring for namespace/objects; small_ophys_inference shows a movi…
…e instead of a single frame
- Loading branch information
1 parent
ea0b27a
commit eb3e202
Showing
25 changed files
with
731 additions
and
1 deletion.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
classdef ImageTimeSeries < deepinterp.Net | ||
|
||
properties (GetAccess=public, SetAccess=protected) | ||
N % number of rows of inputs to the network | ||
M % number of columns of inputs to the network | ||
Npre % Number of input frames needed before the predicted frame | ||
Npost % Number of input frames needed after the predicted frame | ||
end | ||
|
||
|
||
methods | ||
|
||
function obj = ImageTimeSeries(command, options) | ||
% OBJ = NET(COMMAND,...) | ||
% | ||
% Create a new deepinterp.Net.ImageTimeSeries object. | ||
% | ||
% deepinterp.Net.ImageTimeSeries objects remove independent | ||
% frame-by-frame noise in image sequences that are usually time | ||
% series. The network takes Npre video frames _before_ the frame | ||
% to be denoised, and Npost video frames _after_ the frame to be | ||
% denoised, to make a computational prediction of the denoised | ||
% frame. | ||
% | ||
% COMMAND can be: | ||
% | ||
% 'New' : create a new empty network | ||
% 'KerasFile': open and import a Keras file | ||
% | ||
% The function also accepts additional arguments as name/value pairs: | ||
% -------------------------------------------------------------------- | ||
% | Parameter (default) | Description | | ||
% |-------------------------|----------------------------------------| | ||
% | file ('') | Filename to open if reading a file | | ||
% | N (512) | Number of rows in images if new network| | ||
% | M (512) | Number of columns in images for new net| | ||
% | Npre (30) | Number of frames before prediction | | ||
% | Npost (30) | Number of frames after prediction | | ||
% |-------------------------|----------------------------------------| | ||
% | ||
% Example: | ||
% sampleNet = deepinterp.sampleFile(... | ||
% '2019_09_11_23_32_unet_single_1024_mean_absolute_error_Ai93-0450.h5'); | ||
% n = deepinterp.Net.ImageTimeSeries('KerasFile','file',sampleNet,... | ||
% 'Npre',30,'Npost',30); | ||
% | ||
arguments | ||
command (1,:) char {mustBeTextScalar} | ||
options.file (1,:) char {mustBeFile} | ||
options.N (1,1) double {mustBeInteger} = 512; | ||
options.M (1,1) double {mustBeInteger} = 512; | ||
options.Npre (1,1) double {mustBeInteger} = 30; | ||
options.Npost (1,1) double {mustBeInteger} = 30; | ||
end | ||
|
||
superOptions = rmfield(options,{'N','M','Npre','Npost'}); | ||
superInputs = cat(2,{command},... | ||
deepinterp.struct2namevaluepair(superOptions)); | ||
obj = [email protected](superInputs{:}); | ||
obj.Npre = options.Npre; | ||
obj.Npost = options.Npost; | ||
|
||
if strcmp(command,'New'), | ||
disp(['In future we will build a net network here.']); | ||
end; | ||
|
||
obj = obj.SetInputSize(); | ||
|
||
end; % ImageTimeSeries - constructor | ||
|
||
function obj = SetInputSize(obj) | ||
% SETINPUTSIZE - set the input size properties | ||
% | ||
% OBJ = SETINPUTSIZE(OBJ) | ||
% | ||
% Sets the input size properties N and M, and verifies that | ||
% the number of image sequences required in each input add up to | ||
% obj.Npre and obj.Npost. | ||
% | ||
% The values can be examined by OBJ.N, OBJ.M. | ||
% | ||
if isempty(obj.network), | ||
obj.N = NaN; | ||
obj.M = NaN; | ||
else, | ||
inSz = obj.network.Layers(1).InputSize; | ||
obj.N = inSz(1); | ||
obj.M = inSz(2); | ||
assert(inSz(3)==obj.Npre+obj.Npost,... | ||
['Npre + Npost must add up to the total ' ... | ||
'number of image inputs to ' ... | ||
'deepinterp.net.ImageTimeSeries.']); | ||
end; | ||
end; % SetInputSize | ||
|
||
function output = interp(obj, input, options) | ||
% INTERP - apply DeepInterpolation to inputs | ||
% | ||
% OUTPUT = INTERP(DEEPINTERP_IMAGETIMESERIES_OBJ, INPUT) | ||
% | ||
% Apply the DeepInterpolation network to the INPUT data. | ||
% INPUT should be an NxMxT matrix, where N and M match | ||
% the size of the DEEPINTERP_IMAGETIMESERIES_OBJ.N and .M parameters, | ||
% and T must be greater than the Npre + Npost property values of | ||
% DEEPINTERP_IMAGETIMESERIES_OBJ. Interpolation will only be performed | ||
% for frames greater than Npre and less than numel(T) - Npost. The first | ||
% Npre frames and last Npost frames will be equal to the input. | ||
% | ||
arguments | ||
obj (1,1) | ||
input | ||
options.progbar (1,1) logical = true; | ||
end | ||
|
||
output = input; | ||
offsets = [-obj.Npre:-1 1:obj.Npost]; | ||
if options.progbar, | ||
deepinterp.progressbar(); | ||
end; | ||
totalWork = size(input,3)-(obj.Npost+obj.Npre); | ||
for t = obj.Npre+1 : size(input,3) - obj.Npost, | ||
if options.progbar, | ||
deepinterp.progressbar((t-(obj.Npre+1))/totalWork); | ||
end; | ||
output(:,:,t) = predict(obj.network,input(:,:,offsets+t)); | ||
end; | ||
if options.progbar, | ||
progressbar(0.9999999999); | ||
end; | ||
end; % interp() | ||
|
||
% INTERP - apply DeepInterpolation | ||
|
||
|
||
end; % methods | ||
|
||
end % classdef |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
classdef Net | ||
|
||
properties (SetAccess=protected) | ||
network % Deep learning network | ||
end | ||
|
||
|
||
methods | ||
function obj = Net(command,options) | ||
% NET - create a new deepinterp.Net object | ||
% | ||
% OBJ = NET(COMMAND,...) | ||
% | ||
% Create a new deepinterp.Net object according to instructions. | ||
% COMMAND can be: | ||
% | ||
% 'KerasFile': open and import a Keras file | ||
% 'New' : create a new empty network | ||
% | ||
% The function also accepts additional arguments as name/value pairs: | ||
% -------------------------------------------------------------------- | ||
% | Parameter (default) | Description | | ||
% |-------------------------|----------------------------------------| | ||
% | file ('') | Filename to open if reading a file | | ||
% |-------------------------|----------------------------------------| | ||
% | ||
% Example: | ||
% n = deepinterp.Net('KerasFile','file','myKerasFile.H5'); | ||
% | ||
arguments | ||
command (1,:) char {mustBeTextScalar} | ||
options.file (1,:) char {mustBeFile} | ||
end | ||
|
||
switch (command), | ||
case 'KerasFile', | ||
obj.network=deepinterp.importKerasMAE(options.file); | ||
case 'New', | ||
% nothing to do | ||
otherwise, | ||
error(['Unknown command: ' command]); | ||
end; | ||
|
||
end; % Net constructor | ||
|
||
function net_obj = train(net_obj, inputs, outputs) | ||
% TRAIN - train a DeepInterpolation network | ||
% | ||
% NET_OBJ = TRAIN(NET_OBJ, INPUTS, OUTPUTS) | ||
% | ||
% Train a deepinterp.Net object with example | ||
% INPUTS and OUTPUTS. | ||
% | ||
% Subclasses should describe the form of these inputs | ||
% and outputs. | ||
% | ||
|
||
end; % train() | ||
|
||
function output = interp(net_obj, inputs) | ||
% INTERP - apply DeepInterpolation | ||
% | ||
% OUTPUT = INTERP(NET_OBJ, INPUT) | ||
% | ||
% Apply the DeepInterpolation network to the | ||
% INPUT data. | ||
% | ||
% Subclasses should describe the form of the input | ||
% and output. | ||
% | ||
|
||
end; % interp() | ||
end; | ||
|
||
end % classdef | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
function net = importKerasMAE(kerasFile) | ||
% IMPORTKERAS_MAE - import a Keras network, converting placeholder to mae | ||
% | ||
% NET = IMPORTKERASMAE(KERASFILE) | ||
% | ||
% Imports a Keras Deep Learning Network from the file KERASFILE. | ||
% PLACEHOLDER layers are replaced with mean-average-error | ||
% layers. | ||
% | ||
|
||
importednet = importKerasLayers(kerasFile,'ImportWeights',true); | ||
|
||
placeholders = findPlaceholderLayers(importednet); | ||
|
||
regressionnet = replaceLayer(importednet, placeholders.Name , ... | ||
deepinterp.maeRegressionLayer); | ||
|
||
net = assembleNetwork(regressionnet); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
classdef maeRegressionLayer < nnet.layer.RegressionLayer | ||
% Example custom regression layer with mean-absolute-error loss. | ||
|
||
methods | ||
function layer = maeRegressionLayer(name) | ||
% layer = maeRegressionLayer(name) creates a | ||
% mean-absolute-error regression layer and specifies the layer | ||
% name. | ||
|
||
arguments | ||
name(1, 1) string = "MAE Regression" | ||
end | ||
|
||
% Set layer name. | ||
layer.Name = name; | ||
|
||
% Set layer description. | ||
layer.Description = 'Mean absolute error'; | ||
end | ||
|
||
function loss = forwardLoss(layer, Y, T) | ||
% loss = forwardLoss(layer, Y, T) returns the MAE loss between | ||
% the predictions Y and the training targets T. | ||
|
||
% Calculate MAE. | ||
R = size(Y,3); | ||
meanAbsoluteError = sum(abs(Y-T),3)/R; | ||
|
||
% Take mean over mini-batch. | ||
N = size(Y,4); | ||
loss = sum(meanAbsoluteError)/N; | ||
end | ||
end | ||
end |
Oops, something went wrong.