Skip to content

Commit

Permalink
[system-a] test gradientDescentRms by plane
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyuheng committed May 21, 2024
1 parent 9e709d3 commit 3f884ff
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
3 changes: 2 additions & 1 deletion TODO.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# system-a

[system-a] `gradientDescentRms`
[system-a] `gradientDescentAdam`
[system-a] test `gradientDescentAdam` by `plane`

# the-book

Expand Down
8 changes: 5 additions & 3 deletions src/system-a/gradient-descent/gradientDescentRms.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import { smooth } from "./smooth.js"

const stabilizer = 1e-8

export function gradientDescentRms(options: {
// NOTE RMS stands for root mean square.

export function rmsRepresentation(options: {
learningRate: number
decayRate: number
}): Representation<[Tensor, Tensor]> {
Expand All @@ -26,9 +28,9 @@ export function gradientDescentRms(options: {
}
}

export function gradientDescentVelocity(options: {
export function gradientDescentRms(options: {
learningRate: number
decayRate: number
}) {
return gradientDescent(gradientDescentRms(options))
return gradientDescent(rmsRepresentation(options))
}
16 changes: 12 additions & 4 deletions src/system-a/targets/plane.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import assert from "node:assert"
import { test } from "node:test"
import { gradientDescentLonely } from "../gradient-descent/gradientDescentLonely.js"
import { gradientDescentNaked } from "../gradient-descent/gradientDescentNaked.js"
import { gradientDescentRms } from "../gradient-descent/gradientDescentRms.js"
import { gradientDescentVelocity } from "../gradient-descent/gradientDescentVelocity.js"
import type { GradientDescentFn } from "../gradient-descent/index.js"
import { l2Loss } from "../loss.js"
Expand All @@ -25,7 +26,7 @@ test("plane -- extended", () => {
)
})

function testGradientDescentByLine(
function testGradientDescentByPlane(
gradientDescentFn: GradientDescentFn,
options: {
revs: number
Expand Down Expand Up @@ -53,20 +54,27 @@ function testGradientDescentByLine(
}

test("plane -- gradientDescentNaked", () => {
testGradientDescentByLine(gradientDescentNaked({ learningRate: 0.001 }), {
testGradientDescentByPlane(gradientDescentNaked({ learningRate: 0.001 }), {
revs: 15000,
})
})

test("plane -- gradientDescentLonely", () => {
testGradientDescentByLine(gradientDescentLonely({ learningRate: 0.001 }), {
testGradientDescentByPlane(gradientDescentLonely({ learningRate: 0.001 }), {
revs: 15000,
})
})

test("plane -- gradientDescentVelocity", () => {
testGradientDescentByLine(
testGradientDescentByPlane(
gradientDescentVelocity({ learningRate: 0.001, relayFactor: 0.9 }),
{ revs: 5000 },
)
})

test("plane -- gradientDescentRms", () => {
testGradientDescentByPlane(
gradientDescentRms({ learningRate: 0.01, decayRate: 0.9 }),
{ revs: 3000 },
)
})

0 comments on commit 3f884ff

Please sign in to comment.