Skip to content

Commit

Permalink
made pretrained models folder; made pretrained models directly builda…
Browse files Browse the repository at this point in the history
…ble by name only; made a single Net class that handles all known situations
  • Loading branch information
stevevanhooser committed May 12, 2024
1 parent b3dda31 commit aa9fd22
Show file tree
Hide file tree
Showing 21 changed files with 276 additions and 170 deletions.
137 changes: 0 additions & 137 deletions +deepinterp/+Net/ImageTimeSeries.m

This file was deleted.

File renamed without changes.
48 changes: 48 additions & 0 deletions +deepinterp/+internal/getPretrainedModelFilename.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
function [filename,parameters] = getPretrainedModelFilename(modelName)
% GETPRETRAINEDMODELFILENAME - get the file name and parameters for a pretrained model
%
% [FILENAME, PARAMETERS] = GETPRETRAINEDMODELFILENAME(MODELNAME)
%
% Returns the FILENAME and the pretrained model PARAMETERS for a given MODELNAME.
%
% Alternatively, one may call the function with no input arguments for a list
% of valid MODELNAME entries:
%
% [ALLMODELNAMES] = GETPRETRAINEDMODELFILENAME()
%
% Example:
% [filename,parameters] = deepinterp.internal.getPretrainedModelFilenam('TP-Ai93-450');
%

filename = [];
parameters = [];

[allModelNames,allModelParameters] = deepinterp.getPretrainedModels();

if nargin<1,
filename = allModelNames;
return;
end;

idx = find(strcmpi(modelName,allModelNames));

if isempty(idx),
error(['No match for ' modelName ' in known models.']);
end;

parameters = allModelParameters(idx);

filename = fullfile(deepinterp.toolboxpath,'pretrainedModels',parameters.filename);

if ~isfile(filename),
% try to download
FileURL = parameters.url;
disp(['Downloading pretrained model...']);
try,
websave(filename,FileURL);
catch,
error(['Error downloading model.']);
throw(ex);
end;
end;

9 changes: 5 additions & 4 deletions +deepinterp/+internal/importKerasMAE.m
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
importednet = importKerasLayers(kerasFile,'ImportWeights',true);

placeholders = findPlaceholderLayers(importednet);
if ~isempty(placeholders),
importednet = replaceLayer(importednet, placeholders.Name , ...
deepinterp.internal.maeRegressionLayer);
end;

regressionnet = replaceLayer(importednet, placeholders.Name , ...
deepinterp.maeRegressionLayer);

net = assembleNetwork(regressionnet);
net = assembleNetwork(importednet);

16 changes: 16 additions & 0 deletions +deepinterp/+internal/textfile2char.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
function str = textfile2char(filename)
% TEXTFILE2CHAR - Read a text file into a character string
%
% STR = TEXTFILE2CHAR(FILENAME)
%
% This function reads the entire contents of the file FILENAME into
% the character string STR.
%

fid = fopen(filename,'rt');
if fid>0,
str = char(fread(fid,Inf))';
fclose(fid);
else,
error(['Could not open text file named ' filename '.']);
end;
Loading

0 comments on commit aa9fd22

Please sign in to comment.