diff --git a/TODO.md b/TODO.md index df99342..1dc286d 100644 --- a/TODO.md +++ b/TODO.md @@ -1,6 +1,7 @@ # system-a -[system-a] `gradientDescentRms` +[system-a] `gradientDescentAdam` +[system-a] test `gradientDescentAdam` by `plane` # the-book diff --git a/src/system-a/gradient-descent/gradientDescentRms.ts b/src/system-a/gradient-descent/gradientDescentRms.ts index 416bc45..f7a26f6 100644 --- a/src/system-a/gradient-descent/gradientDescentRms.ts +++ b/src/system-a/gradient-descent/gradientDescentRms.ts @@ -6,7 +6,9 @@ import { smooth } from "./smooth.js" const stabilizer = 1e-8 -export function gradientDescentRms(options: { +// NOTE RMS stands for root mean square. + +export function rmsRepresentation(options: { learningRate: number decayRate: number }): Representation<[Tensor, Tensor]> { @@ -26,9 +28,9 @@ export function gradientDescentRms(options: { } } -export function gradientDescentVelocity(options: { +export function gradientDescentRms(options: { learningRate: number decayRate: number }) { - return gradientDescent(gradientDescentRms(options)) + return gradientDescent(rmsRepresentation(options)) } diff --git a/src/system-a/targets/plane.test.ts b/src/system-a/targets/plane.test.ts index 21f1b9f..e013ab7 100644 --- a/src/system-a/targets/plane.test.ts +++ b/src/system-a/targets/plane.test.ts @@ -2,6 +2,7 @@ import assert from "node:assert" import { test } from "node:test" import { gradientDescentLonely } from "../gradient-descent/gradientDescentLonely.js" import { gradientDescentNaked } from "../gradient-descent/gradientDescentNaked.js" +import { gradientDescentRms } from "../gradient-descent/gradientDescentRms.js" import { gradientDescentVelocity } from "../gradient-descent/gradientDescentVelocity.js" import type { GradientDescentFn } from "../gradient-descent/index.js" import { l2Loss } from "../loss.js" @@ -25,7 +26,7 @@ test("plane -- extended", () => { ) }) -function testGradientDescentByLine( +function testGradientDescentByPlane( gradientDescentFn: GradientDescentFn, options: { revs: number @@ -53,20 +54,27 @@ function testGradientDescentByLine( } test("plane -- gradientDescentNaked", () => { - testGradientDescentByLine(gradientDescentNaked({ learningRate: 0.001 }), { + testGradientDescentByPlane(gradientDescentNaked({ learningRate: 0.001 }), { revs: 15000, }) }) test("plane -- gradientDescentLonely", () => { - testGradientDescentByLine(gradientDescentLonely({ learningRate: 0.001 }), { + testGradientDescentByPlane(gradientDescentLonely({ learningRate: 0.001 }), { revs: 15000, }) }) test("plane -- gradientDescentVelocity", () => { - testGradientDescentByLine( + testGradientDescentByPlane( gradientDescentVelocity({ learningRate: 0.001, relayFactor: 0.9 }), { revs: 5000 }, ) }) + +test("plane -- gradientDescentRms", () => { + testGradientDescentByPlane( + gradientDescentRms({ learningRate: 0.01, decayRate: 0.9 }), + { revs: 3000 }, + ) +})