Skip to content

Commit

Permalink
Merge pull request #207 from rcnl-org/FreeFinalTime
Browse files Browse the repository at this point in the history
Added free final time formulation
  • Loading branch information
cvhammond committed Jun 14, 2023
2 parents 8d482e4 + ab77936 commit b55942d
Show file tree
Hide file tree
Showing 20 changed files with 208 additions and 178 deletions.
6 changes: 2 additions & 4 deletions src/DesignOptimization/calcDesignOptimizationIntegrand.m
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
generateCostTermStruct("continuous", "DesignOptimization");
integrand = calcTreatmentOptimizationCost( ...
costTermCalculations, allowedTypes, values, modeledValues, auxdata);
integrand = scaleToBounds(integrand, auxdata.maxIntegral, auxdata.minIntegral);
integrand = integrand ./ (auxdata.maxIntegral - auxdata.minIntegral);
integrand = integrand .^ 2;
end


end
6 changes: 5 additions & 1 deletion src/DesignOptimization/calcDesignOptimizationObjective.m
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@
% permissions and limitations under the License. %
% ----------------------------------------------------------------------- %

function objective = calcDesignOptimizationObjective(discrete, continuous)
function objective = calcDesignOptimizationObjective(discrete, ...
continuous, finalTime, inputs)
continuousObjective = sum(continuous) / length(continuous);
if isfield(inputs, "finalTimeRange")
continuousObjective = continuousObjective / finalTime;
end
discreteObjective = sum(discrete) / length(discrete);
if isnan(discreteObjective); discreteObjective = 0; end
objective = continuousObjective + discreteObjective;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
modeledValues, inputs.auxdata);
% discrete = computeStaticParameterCost(inputs);
output.objective = calcDesignOptimizationObjective(discrete, ...
inputs.phase.integral);
inputs.phase.integral, values.time(end), inputs.auxdata);
end

function cost = computeStaticParameterCost(inputs)
Expand Down
10 changes: 7 additions & 3 deletions src/DesignOptimization/computeDesignOptimizationMainFunction.m
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
% ----------------------------------------------------------------------- %

function output = computeDesignOptimizationMainFunction(inputs, params)
bounds = setupProblemBounds(inputs);
guess = setupCommonOptimalControlInitialGuess(inputs);
bounds = setupProblemBounds(inputs, guess);
guess = addUserDefinedTermsToGuess(guess, inputs);
setup = setupCommonOptimalControlSolverSettings(inputs, ...
bounds, guess, params, ...
Expand All @@ -41,7 +41,7 @@
output.solution = solution;
end

function bounds = setupProblemBounds(inputs)
function bounds = setupProblemBounds(inputs, guess)
bounds = setupCommonOptimalControlBounds(inputs);
% setup parameter bounds
if strcmp(inputs.controllerType, 'synergy_driven')
Expand All @@ -63,6 +63,10 @@
0.5];
end
end
if isfield(inputs, "finalTimeRange")
bounds.phase.finaltime.lower = guess.phase.time(end) - (0.5 - guess.phase.time(end));
bounds.phase.finaltime.upper = 0.5;
end
end

function guess = addUserDefinedTermsToGuess(guess, inputs)
Expand All @@ -78,4 +82,4 @@
variable.upper_bounds, ...
variable.lower_bounds)];
end
end
end
12 changes: 7 additions & 5 deletions src/DesignOptimization/getDesignOptimizationValueStruct.m
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@
fnval(params.splineSynergyActivations, values.time);
end
end
for i = 1:length(params.userDefinedVariables)
values.(params.userDefinedVariables{i}.type) = scaleToOriginal( ...
inputs.parameter(i, 1), ...
params.userDefinedVariables{i}.upper_bounds, ...
params.userDefinedVariables{i}.lower_bounds);
if isfield(params, 'userDefinedVariables')
for i = 1:length(params.userDefinedVariables)
values.(params.userDefinedVariables{i}.type) = scaleToOriginal( ...
inputs.parameter(i, 1), ...
params.userDefinedVariables{i}.upper_bounds, ...
params.userDefinedVariables{i}.lower_bounds);
end
end
end
5 changes: 5 additions & 0 deletions src/DesignOptimization/parseDesignOptimizationSettingsTree.m
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@
if(isstruct(maxControlTorques))
inputs.maxControlTorquesMultiple = getDoubleFromField(maxControlTorques);
end
finalTimeRange = getFieldByName(designVariableTree, ...
'final_time_range');
if(isstruct(finalTimeRange))
inputs.finalTimeRange = getDoubleFromField(finalTimeRange);
end
end
inputs.toolName = "DesignOptimization";
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
parseMtpStandard(findFileListFromPrefixList( ...
fullfile(mtpResultsDirectory, "muscleActivations"), inputs.prefixes));
osimxFileName = getFieldByName(tree, "input_osimx_file");
if ~isstruct(osimxFileName)
if ~isstruct(osimxFileName) || isempty(osimxFileName.Text)
throw(MException('', 'An input .osimx file is required if using data from MTP.'))
end
inputs.mtpMuscleData = parseOsimxFile(osimxFileName.Text);
Expand Down
6 changes: 0 additions & 6 deletions src/TrackingOptimization/TrackingOptimizationTool.m
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,7 @@
function TrackingOptimizationTool(settingsFileName)
settingsTree = xml2struct(settingsFileName);
[inputs, params] = parseTrackingOptimizationSettingsTree(settingsTree);

tic

[outputs, inputs] = TrackingOptimization(inputs, params);

toc

reportTreatmentOptimizationResults(outputs, inputs);
saveTrackingOptimizationResults(outputs, inputs);
end
2 changes: 0 additions & 2 deletions src/TrackingOptimization/getTrackingOptimizationValueStruct.m
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,5 @@
values.synergyWeights = getSynergyWeightsFromGroups(...
params.synergyWeightsGuess, params);
end
values.controlSynergyActivations = inputs.control(:, ...
params.numCoordinates + 1 : params.numCoordinates + params.numSynergies);
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@
generateCostTermStruct("continuous", "VerificationOptimization");
integrand = calcTreatmentOptimizationCost( ...
costTermCalculations, allowedTypes, values, modeledValues, auxdata);
integrand = scaleToBounds(integrand, auxdata.maxIntegral, auxdata.minIntegral);
integrand = integrand ./ (auxdata.maxIntegral - auxdata.minIntegral);
integrand = integrand .^ 2;
end
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,28 @@
% permissions and limitations under the License. %
% ----------------------------------------------------------------------- %

function cost = calcTrackingControllerIntegrand(auxdata, values, ...
function cost = calcTrackingControllerIntegrand(auxdata, values, time, ...
controllerName)

switch auxdata.controllerType
case 'synergy_driven'
indx = find(strcmp(convertCharsToStrings( ...
auxdata.synergyLabels), controllerName));
synergyActivations = fnval(auxdata.splineSynergyActivations, values.time)';
if auxdata.splineJointMoments.dim > 1
synergyActivations = fnval(auxdata.splineSynergyActivations, time)';
else
synergyActivations = fnval(auxdata.splineSynergyActivations, time);
end
cost = calcTrackingCostArrayTerm(synergyActivations, values.controlSynergyActivations, indx);
case 'torque_driven'
indx1 = find(strcmp(convertCharsToStrings( ...
auxdata.inverseDynamicMomentLabels), controllerName));
indx2 = find(strcmp(convertCharsToStrings( ...
strcat(auxdata.controlTorqueNames, '_moment')), controllerName));
experimentalJointMoments = fnval(auxdata.splineJointMoments, values.time)';
if auxdata.splineJointMoments.dim > 1
experimentalJointMoments = fnval(auxdata.splineJointMoments, time)';
else
experimentalJointMoments = fnval(auxdata.splineJointMoments, time);
end
cost = experimentalJointMoments(:, indx1) - values.controlTorques(:, indx2);
end
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
% ----------------------------------------------------------------------- %

function inputs = getDesignVariableInputBounds(inputs)
inputs.maxTime = max(inputs.experimentalTime);
if isfield(inputs, "finalTimeRange")
inputs.maxTime = max(inputs.experimentalTime) + inputs.finalTimeRange;
else
inputs.maxTime = max(inputs.experimentalTime);
end
inputs.minTime = min(inputs.experimentalTime);

maxStatePositions = max(inputs.experimentalJointAngles) + ...
Expand Down
3 changes: 2 additions & 1 deletion src/core/TreatmentOptimization/generateCostTermStruct.m
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
costTermCalculations.coordinate_tracking = @(values, modeledValues, auxdata, costTerm) ...
calcTrackingCoordinateIntegrand( ...
auxdata, ...
values.time, ...
values.time/values.time(end), ...
values.statePositions, ...
costTerm.coordinate ...
);
Expand All @@ -106,6 +106,7 @@
calcTrackingControllerIntegrand( ...
auxdata, ...
values, ...
values.time/values.time(end), ...
costTerm.controller ...
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
parseElementTextByNameOrAlternate(tree, "optimizeSynergyVectors", 0));
inputs = parseTreatmentOptimizationDataDirectory(tree, inputs);
inputs.initialGuess = getGpopsInitialGuess(tree);
inputs.experimentalTime = inputs.experimentalTime / ...
inputs.experimentalTime(end);
inputs.costTerms = parseRcnlCostTermSet( ...
getFieldByNameOrError(tree, 'RCNLCostTermSet').RCNLCostTerm);
inputs.path = getPathConstraintTerms(tree);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
if ~strcmp(params.controllerType, 'synergy_driven')
values.controlTorques = control(:, params.numCoordinates + 1 : ...
params.numCoordinates + params.numTorqueControls);
else
values.controlSynergyActivations = control(:, ...
params.numCoordinates + 1 : params.numCoordinates + params.numSynergies);
end

end
17 changes: 10 additions & 7 deletions src/core/TreatmentOptimization/inverseDynamics.m
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
% National Institutes of Health (R01 EB030520). %
% %
% Copyright (c) 2021 Rice University and the Authors %
% Author(s): Spencer Williams %
% Author(s): Spencer Williams, Marleny Vega %
% %
% Licensed under the Apache License, Version 2.0 (the "License"); %
% you may not use this file except in compliance with the License. %
Expand All @@ -25,13 +25,16 @@
% permissions and limitations under the License. %
% ----------------------------------------------------------------------- %

function IDLoads = inverseDynamics(time,q,qp,qpp,IKLabels,AppliedLoads, ...
modelFile)
function inverseDynamicMoments = inverseDynamics(time, jointAngles, ...
jointVelocities, jointAccelerations, coordinateLabels, appliedLoads, ...
modelName)
if isequal(mexext, 'mexw64')
IDLoads = inverseDynamicsMexWindows(time,q,qp,qpp,IKLabels, ...
AppliedLoads);
inverseDynamicMoments = inverseDynamicsMexWindows(time, jointAngles, ...
jointVelocities, jointAccelerations, coordinateLabels, ...
appliedLoads);
else
IDLoads = inverseDynamicsMatlabParallel(time,q,qp,qpp,IKLabels, ...
AppliedLoads,modelFile);
inverseDynamicMoments = inverseDynamicsMatlabParallel(time, ...
jointAngles, jointVelocities, jointAccelerations, coordinateLabels, ...
appliedLoads, modelName);
end
end
Loading

0 comments on commit b55942d

Please sign in to comment.