From fd3de4b92a28a3264a34792ea6182b26b7498e59 Mon Sep 17 00:00:00 2001 From: Xie Yuheng Date: Tue, 21 May 2024 20:16:54 +0800 Subject: [PATCH] [system-a] rename `sqrt` to `squareRoot` --- TODO.md | 2 +- src/system-a/gradient-descent/gradientDescentRms.ts | 4 ++-- src/system-a/toys/toys.ts | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/TODO.md b/TODO.md index e3c24df..06dd1e3 100644 --- a/TODO.md +++ b/TODO.md @@ -1,6 +1,6 @@ # system-a -[system-a] rename `sqrt` to `squareRoot` +[the-book] rename `sqrt` to `squareRoot` [system-a] `gradientDescentRms` # the-book diff --git a/src/system-a/gradient-descent/gradientDescentRms.ts b/src/system-a/gradient-descent/gradientDescentRms.ts index b593725..416bc45 100644 --- a/src/system-a/gradient-descent/gradientDescentRms.ts +++ b/src/system-a/gradient-descent/gradientDescentRms.ts @@ -1,5 +1,5 @@ import { tensorZeros, type Tensor } from "../tensor/index.js" -import { add, div, mul, sqrt, square, sub } from "../toys/index.js" +import { add, div, mul, square, squareRoot, sub } from "../toys/index.js" import type { Representation } from "./Representation.js" import { gradientDescent } from "./gradientDescent.js" import { smooth } from "./smooth.js" @@ -19,7 +19,7 @@ export function gradientDescentRms(options: { // NOTE Add `stabilizer` to avoid `div` by zero. const adaptiveLearningRate = div( options.learningRate, - add(stabilizer, sqrt(r)), + add(stabilizer, squareRoot(r)), ) return [sub(p, mul(adaptiveLearningRate, g)), r] }, diff --git a/src/system-a/toys/toys.ts b/src/system-a/toys/toys.ts index e418399..2e91f20 100644 --- a/src/system-a/toys/toys.ts +++ b/src/system-a/toys/toys.ts @@ -34,7 +34,7 @@ export const exptScalar = prim2( (ra, rb, z) => [rb * ra ** (rb - 1) * z, ra ** rb * Math.log(ra) * z], ) -export const sqrtScalar = prim1( +export const squareRootScalar = prim1( (x) => Math.sqrt(x), (ra, z) => (1 / 2) * ra ** (-1 / 2) * z, ) @@ -51,7 +51,7 @@ export const mul = extend2(mulScalar) export const div = extend2(divScalar) export const log = extend1(logScalar) export const expt = extend2(exptScalar) -export const sqrt = extend1(sqrtScalar) +export const squareRoot = extend1(squareRootScalar) export const square = extend1(squareScalar) export const lt = comparator((x, y) => x < y)