Skip to content

Commit edca927

Browse files
committed
[system-a] tensor/ -- extract assertions
1 parent 30c667e commit edca927

File tree

5 files changed

+35
-34
lines changed

5 files changed

+35
-34
lines changed

TODO.md

-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# system-a
22

3-
[system-a] `tensor/` -- extract `assert`
4-
53
[system-a] `tensor/` -- `tensorZeros` -- zeros of the same shape
64
[system-a] `velocityRepresentation` -- takes `options: { velocityAccumulationFactor }`
75
[system-a] `gradientDescentVelocity`

src/system-a/tensor/Scalar.ts

-6
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,6 @@ export function isScalar(x: any): x is Scalar {
2020
return typeof x === "number" || isDual(x)
2121
}
2222

23-
export function assertScalar(t: any): asserts t is Scalar {
24-
if (!isScalar(t)) {
25-
throw new Error(`[assertScalar] ${t}`)
26-
}
27-
}
28-
2923
export function scalarReal(x: Scalar): number {
3024
if (isDual(x)) {
3125
return x.real

src/system-a/tensor/Tensor.ts

-26
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,7 @@
11
import { isScalar, type Scalar } from "./Scalar.js"
2-
import { rank } from "./rank.js"
3-
import { tensorAlmostEqual } from "./tensorAlmostEqual.js"
42

53
export type Tensor = Scalar | Array<Tensor>
64

75
export function isTensor(x: any): x is Tensor {
86
return isScalar(x) || (x instanceof Array && x.every(isTensor))
97
}
10-
11-
export function assertTensor1(t: Tensor): asserts t is Array<Scalar> {
12-
if (rank(t) !== 1) {
13-
throw new Error(`[assertTensor1] ${t}`)
14-
}
15-
}
16-
17-
export function assertTensorArray(x: any): asserts x is Array<Tensor> {
18-
if (x instanceof Array && x.every(isTensor)) {
19-
return
20-
}
21-
22-
throw new Error(`[assertTensorArray] ${x}`)
23-
}
24-
25-
export function assertTensorAlmostEqual(
26-
x: Tensor,
27-
y: Tensor,
28-
epsilon: number,
29-
): void {
30-
if (!tensorAlmostEqual(x, y, epsilon)) {
31-
throw new Error(`[assertTensorAlmostEqual] [${x}], [${y}], ${epsilon}`)
32-
}
33-
}

src/system-a/tensor/assertions.ts

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import { isScalar, type Scalar } from "./Scalar.js"
2+
import { isTensor, type Tensor } from "./Tensor.js"
3+
import { rank } from "./rank.js"
4+
import { tensorAlmostEqual } from "./tensorAlmostEqual.js"
5+
6+
export function assertScalar(t: any): asserts t is Scalar {
7+
if (!isScalar(t)) {
8+
throw new Error(`[assertScalar] ${t}`)
9+
}
10+
}
11+
12+
export function assertTensor1(t: Tensor): asserts t is Array<Scalar> {
13+
if (rank(t) !== 1) {
14+
throw new Error(`[assertTensor1] ${t}`)
15+
}
16+
}
17+
18+
export function assertTensorArray(x: any): asserts x is Array<Tensor> {
19+
if (x instanceof Array && x.every(isTensor)) {
20+
return
21+
}
22+
23+
throw new Error(`[assertTensorArray] ${x}`)
24+
}
25+
26+
export function assertTensorAlmostEqual(
27+
x: Tensor,
28+
y: Tensor,
29+
epsilon: number,
30+
): void {
31+
if (!tensorAlmostEqual(x, y, epsilon)) {
32+
throw new Error(`[assertTensorAlmostEqual] [${x}], [${y}], ${epsilon}`)
33+
}
34+
}

src/system-a/tensor/index.ts

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
export * from "./Scalar.js"
22
export * from "./Tensor.js"
3+
export * from "./assertions.js"
34
export * from "./rank.js"
45
export * from "./shape.js"
56
export * from "./tensorAlmostEqual.js"

0 commit comments

Comments
 (0)