Skip to content

Commit f93d1c4

Browse files
committed
[system-a] extract gradient-descent/
1 parent f86a855 commit f93d1c4

22 files changed

+45
-63
lines changed

TODO.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
[system-a] `Representation` -- `{ inflate, deflate, update }`
44

5-
- naked representation
6-
- lonely representation
5+
- with `velocityAccumulationFactor`
76

7+
[system-a] `nakedRepresentation`
8+
[system-a] `lonelyRepresentation`
89
[system-a] `velocityRepresentation` as an instance of `Representation`
910

1011
# the-book

src/system-a/Scalar.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import {
22
gradientStateGetWithDefault,
33
gradientStateSet,
44
type GradientState,
5-
} from "./index.js"
5+
} from "./gradient-descent/index.js"
66

77
export type Dual = { "@type": "Dual"; real: number; link: Link }
88

@@ -26,6 +26,12 @@ export function isScalar(x: any): x is Scalar {
2626
return typeof x === "number" || isDual(x)
2727
}
2828

29+
export function assertScalar(t: any): asserts t is Scalar {
30+
if (!isScalar(t)) {
31+
throw new Error(`[assertScalar] ${t}`)
32+
}
33+
}
34+
2935
export function scalarReal(x: Scalar): number {
3036
if (isDual(x)) {
3137
return x.real

src/system-a/Tensor.ts

+1-13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { isScalar, scalarReal, type Scalar } from "./index.js"
1+
import { isScalar, scalarReal, type Scalar } from "./Scalar.js"
22
import { sub } from "./toys/index.js"
33

44
export type Tensor = Scalar | Array<Tensor>
@@ -21,18 +21,6 @@ export function rank(t: Tensor): number {
2121
return shape(t).length
2222
}
2323

24-
export function assertScalar(t: Tensor): asserts t is Scalar {
25-
if (!isScalar(t)) {
26-
throw new Error(`[assertScalar] ${t}`)
27-
}
28-
}
29-
30-
export function assertNotScalar(t: Tensor): asserts t is Array<Tensor> {
31-
if (isScalar(t)) {
32-
throw new Error(`[assertNotScalar] ${t}`)
33-
}
34-
}
35-
3624
export function assertTensor1(t: Tensor): asserts t is Array<Scalar> {
3725
if (rank(t) !== 1) {
3826
throw new Error(`[assertTensor1] ${t}`)

src/system-a/GradientState.ts renamed to src/system-a/gradient-descent/GradientState.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { type Scalar } from "./index.js"
1+
import { type Scalar } from "../Scalar.js"
22

33
export type GradientState = Map<Scalar, number>
44

src/system-a/gradient.test.ts renamed to src/system-a/gradient-descent/gradient.test.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import assert from "node:assert"
22
import { test } from "node:test"
3-
import { gradient } from "./index.js"
4-
import { add, mul, sum } from "./toys/index.js"
3+
import { add, mul, sum } from "../toys/index.js"
4+
import { gradient } from "./gradient.js"
55

66
test("gradient -- add", () => {
77
assert.deepStrictEqual(gradient(add, [1, 1]), [1, 1])

src/system-a/gradient.ts renamed to src/system-a/gradient-descent/gradient.ts

+3-7
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
1+
import { isScalar, scalarLink, scalarTruncate, type Scalar } from "../Scalar.js"
2+
import { tensorMap, type Tensor } from "../Tensor.js"
13
import {
24
emptyGradientState,
35
gradientStateGetWithDefault,
4-
isScalar,
5-
scalarLink,
6-
scalarTruncate,
7-
tensorMap,
86
type GradientState,
9-
type Scalar,
10-
type Tensor,
11-
} from "./index.js"
7+
} from "./GradientState.js"
128

139
// The effect of `gradient` on a `DifferentiableFn`
1410
// is `sum` of all elements of it's result tensor.

src/system-a/gradientDescent.ts renamed to src/system-a/gradient-descent/gradientDescent.ts

+5-9
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
import { zip } from "../utils/zip.js"
2-
import {
3-
assertTensorArray,
4-
gradient,
5-
tensorReal,
6-
type Scalar,
7-
type Tensor,
8-
} from "./index.js"
9-
import { mul, sub } from "./toys/index.js"
1+
import { zip } from "../../utils/zip.js"
2+
import type { Scalar } from "../Scalar.js"
3+
import { assertTensorArray, tensorReal, type Tensor } from "../Tensor.js"
4+
import { mul, sub } from "../toys/index.js"
5+
import { gradient } from "./index.js"
106

117
export function gradientDescent(
128
objective: (...ps: Array<Tensor>) => Scalar,
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
export * from "./GradientState.js"
2-
export * from "./Scalar.js"
3-
export * from "./Tensor.js"
42
export * from "./gradient.js"
3+
export * from "./gradientDescent.js"

src/system-a/loss.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import { assertScalar, type Scalar, type Tensor } from "./index.js"
1+
import { assertScalar, type Scalar } from "./Scalar.js"
2+
import { type Tensor } from "./Tensor.js"
23
import { square, sub, sum } from "./toys/index.js"
34

45
export type Target = (xs: Tensor) => (...ps: Array<Tensor>) => Tensor

src/system-a/targets/line.test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import assert from "assert"
22
import { test } from "node:test"
33
import { assertTensorAlmostEqual, tensorReal } from "../Tensor.js"
4-
import { gradientDescent } from "../gradientDescent.js"
4+
import { gradientDescent } from "../gradient-descent/index.js"
55
import { l2Loss } from "../loss.js"
66
import { samplingObjective } from "../samplingObjective.js"
77
import { line } from "./line.js"

src/system-a/targets/line.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { type Tensor } from "../index.js"
1+
import { type Tensor } from "../Tensor.js"
22
import { add, mul } from "../toys/index.js"
33

44
export function line(x: Tensor): (...ps: [Tensor, Tensor]) => Tensor {

src/system-a/targets/plane.test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import assert from "node:assert"
22
import { test } from "node:test"
33
import { assertTensorAlmostEqual, tensorReal } from "../Tensor.js"
4-
import { gradientDescent } from "../gradientDescent.js"
4+
import { gradientDescent } from "../gradient-descent/index.js"
55
import { l2Loss } from "../loss.js"
66
import { samplingObjective } from "../samplingObjective.js"
77
import { plane } from "./plane.js"

src/system-a/targets/plane.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { type Tensor } from "../index.js"
1+
import { type Tensor } from "../Tensor.js"
22
import { dot } from "../toys/dot.js"
33
import { add } from "../toys/index.js"
44

src/system-a/targets/quad.test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import assert from "node:assert"
22
import { test } from "node:test"
33
import { assertTensorAlmostEqual, tensorReal } from "../Tensor.js"
4-
import { gradientDescent } from "../gradientDescent.js"
4+
import { gradientDescent } from "../gradient-descent/index.js"
55
import { l2Loss } from "../loss.js"
66
import { samplingObjective } from "../samplingObjective.js"
77
import { quad } from "./quad.js"

src/system-a/targets/quad.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { type Tensor } from "../index.js"
1+
import { type Tensor } from "../Tensor.js"
22
import { add, mul, square } from "../toys/index.js"
33

44
export function quad(x: Tensor): (...ps: [Tensor, Tensor, Tensor]) => Tensor {

src/system-a/toys/comparator.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { scalarReal, type Scalar } from "../index.js"
1+
import { scalarReal, type Scalar } from "../Scalar.js"
22

33
export function comparator(
44
p: (x: number, y: number) => boolean,

src/system-a/toys/dot.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { type Tensor } from "../index.js"
1+
import { type Tensor } from "../Tensor.js"
22
import { mul, sum } from "./index.js"
33

44
export function dot(w: Tensor, x: Tensor): Tensor {

src/system-a/toys/extend.ts

+6-11
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
11
import { zip } from "../../utils/zip.js"
2-
import {
3-
assertNotScalar,
4-
isScalar,
5-
rank,
6-
type Scalar,
7-
type Tensor,
8-
} from "../index.js"
2+
import { isScalar, type Scalar } from "../Scalar.js"
3+
import { assertTensorArray, rank, type Tensor } from "../Tensor.js"
94

105
export function extend2(
116
fn: (x: Scalar, y: Scalar) => Scalar,
127
): (x: Tensor, y: Tensor) => Tensor {
138
return function extendedFn(x: Tensor, y: Tensor): Tensor {
149
if (rank(x) > rank(y)) {
15-
assertNotScalar(x)
10+
assertTensorArray(x)
1611
return x.map((x) => extendedFn(x, y))
1712
}
1813

1914
if (rank(x) < rank(y)) {
20-
assertNotScalar(y)
15+
assertTensorArray(y)
2116
return y.map((y) => extendedFn(x, y))
2217
}
2318

@@ -29,8 +24,8 @@ export function extend2(
2924
return fn(x, y)
3025
}
3126

32-
assertNotScalar(x)
33-
assertNotScalar(y)
27+
assertTensorArray(x)
28+
assertTensorArray(y)
3429

3530
return zip(x, y).map(([x, y]) => extendedFnSameShape(x, y))
3631
}

src/system-a/toys/prim.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { Dual, scalarLink, scalarReal, type Scalar } from "../index.js"
1+
import { Dual, scalarLink, scalarReal, type Scalar } from "../Scalar.js"
22

33
export function prim1(
44
realFn: (ra: number) => number,

src/system-a/toys/sum.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1+
import type { Scalar } from "../Scalar.js"
12
import {
2-
assertNotScalar,
33
assertTensor1,
4+
assertTensorArray,
45
rank,
5-
type Scalar,
66
type Tensor,
7-
} from "../index.js"
7+
} from "../Tensor.js"
88
import { addScalar } from "./index.js"
99

1010
export function sum1(xs: Array<Scalar>): Scalar {
1111
return xs.reduce((x, result) => addScalar(x, result), 0)
1212
}
1313

1414
export function sum(x: Tensor): Tensor {
15-
assertNotScalar(x)
15+
assertTensorArray(x)
1616
if (rank(x) === 1) {
1717
assertTensor1(x)
1818
return sum1(x)

src/system-a/toys/toys-by-hand.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { Dual, scalarLink, scalarReal, type Scalar } from "../index.js"
1+
import { Dual, scalarLink, scalarReal, type Scalar } from "../Scalar.js"
22

33
function addScalar(da: Scalar, db: Scalar): Scalar {
44
return Dual(scalarReal(da) + scalarReal(db), (_d, z, state) => {

0 commit comments

Comments
 (0)