diff --git a/src/system-a/models/iris/irisModel.ts b/src/system-a/models/iris/irisModel.ts index 88a2c5c..b9ac661 100644 --- a/src/system-a/models/iris/irisModel.ts +++ b/src/system-a/models/iris/irisModel.ts @@ -34,9 +34,9 @@ export function irisTrainParameters(): Array { const initParameters = denseBlockInitParameters(irisNetwork.shapes) gradientDescentNaked - const gradientDescentFn = gradientDescentNaked({ - learningRate: 0.0002, - }) + // const gradientDescentFn = gradientDescentNaked({ + // learningRate: 0.0002, + // }) gradientDescentRms // const gradientDescentFn = gradientDescentRms({ @@ -45,11 +45,11 @@ export function irisTrainParameters(): Array { // }) gradientDescentAdam - // const gradientDescentFn = gradientDescentAdam({ - // learningRate: 0.001, - // decayRate: 0.9, - // relayFactor: 0.85, - // }) + const gradientDescentFn = gradientDescentAdam({ + learningRate: 0.0001, + decayRate: 0.9, + relayFactor: 0.85, + }) return gradientDescentFn(objective, initParameters, { revs: 2000,