Skip to content

Commit 6fe879c

Browse files
committed
[system-a] tensor/ -- extract Link
1 parent edca927 commit 6fe879c

File tree

3 files changed

+23
-20
lines changed

3 files changed

+23
-20
lines changed

src/system-a/tensor/Link.ts

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import {
2+
gradientStateGetWithDefault,
3+
gradientStateSet,
4+
type GradientState,
5+
} from "../gradient-descent/index.js"
6+
import { type Scalar } from "./Scalar.js"
7+
8+
export type Link = (
9+
y: Scalar,
10+
accumulator: number,
11+
state: GradientState,
12+
) => GradientState
13+
14+
export function endOfChain(
15+
d: Scalar,
16+
z: number,
17+
state: GradientState,
18+
): GradientState {
19+
const g = gradientStateGetWithDefault(state, d, 0)
20+
return gradientStateSet(state, d, z + g)
21+
}

src/system-a/tensor/Scalar.ts

+1-20
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
import {
2-
gradientStateGetWithDefault,
3-
gradientStateSet,
4-
type GradientState,
5-
} from "../gradient-descent/index.js"
1+
import { endOfChain, type Link } from "./Link.js"
62

73
export type Dual = { "@type": "Dual"; real: number; link: Link }
84

@@ -39,18 +35,3 @@ export function scalarLink(x: Scalar): Link {
3935
export function scalarTruncate(x: Scalar): Scalar {
4036
return Dual(scalarReal(x), endOfChain)
4137
}
42-
43-
export type Link = (
44-
y: Scalar,
45-
accumulator: number,
46-
state: GradientState,
47-
) => GradientState
48-
49-
export function endOfChain(
50-
d: Scalar,
51-
z: number,
52-
state: GradientState,
53-
): GradientState {
54-
const g = gradientStateGetWithDefault(state, d, 0)
55-
return gradientStateSet(state, d, z + g)
56-
}

src/system-a/tensor/index.ts

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
export * from "./Link.js"
12
export * from "./Scalar.js"
23
export * from "./Tensor.js"
34
export * from "./assertions.js"

0 commit comments

Comments
 (0)