Skip to content

Commit

Permalink
Tensor operations added.
Browse files Browse the repository at this point in the history
  • Loading branch information
Smoren committed May 16, 2024
1 parent b28f847 commit 78c0777
Show file tree
Hide file tree
Showing 3 changed files with 345 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/lib/math/operations.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { createEmptyMatrix, createEmptyTensor } from './factories';
import type { Tensor } from './types';

export function arrayUnaryOperation<T>(
input: Array<T>,
Expand Down Expand Up @@ -28,6 +29,32 @@ export function arrayBinaryOperation<T>(
return result;
}

export function tensorUnaryOperation<T>(operand: Tensor<T>, operation: (x: T) => T): Tensor<T> {
const result: Tensor<T> = [];
for (const item of operand) {
if (item instanceof Array) {
result.push(tensorUnaryOperation(item, operation));
} else {
result.push(operation(item));
}
}
return result;
}

export function tensorBinaryOperation<T>(lhs: Tensor<T>, rhs: Tensor<T>, operation: (x: T, y: T) => T): Tensor<T> {
const result: Tensor<T> = [];
for (let i = 0; i < lhs.length; ++i) {
const lhsItem = lhs[i];
const rhsItem = rhs[i];
if (lhsItem instanceof Array) {
result.push(tensorBinaryOperation(lhsItem, rhsItem as Tensor<T>, operation));
} else {
result.push(operation(lhsItem, rhsItem as T));
}
}
return result;
}

export function concatArrays<T>(lhs: T[], rhs: T[]): T[] {
return [...lhs, ...rhs];
}
Expand Down
8 changes: 8 additions & 0 deletions src/lib/math/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ export type NumericVector = Array<number>;
*/
export type ImmutableNumericVector = ReadonlyArray<number>;

type TensorItem<T> = T | TensorItem<T>[];

/**
* Multi-dimensional tensor
* @public
*/
export type Tensor<T> = TensorItem<T>[];

/**
* Interface of vector
*/
Expand Down
310 changes: 310 additions & 0 deletions tests/math/operations.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
import { describe, expect, it } from '@jest/globals'
import type { Tensor } from "../../src/lib/math/types";
import { tensorBinaryOperation, tensorUnaryOperation } from '../../src/lib/math/operations';

describe.each([
...dataProviderForTensorUnaryOperation(),
] as Array<[Tensor<number>, (x: number) => number, Tensor<number>]>)(
'Tensor Unary Operation Test',
(tensor, operation, expected) => {
it('', () => {
const result = tensorUnaryOperation(tensor, operation);
expect(result).toEqual(expected);
});
},
);

function dataProviderForTensorUnaryOperation(): Array<[Tensor<number>, (x: number) => number, Tensor<number>]> {
return [
[
[],
(x: number) => x*2,
[],
],
[
[1],
(x: number) => x*2,
[2],
],
[
[1, 2, 3],
(x: number) => x*2,
[2, 4, 6],
],
[
[
[1, 2, 3],
[4, 5, 6],
],
(x: number) => x*2,
[
[2, 4, 6],
[8, 10, 12],
],
],
[
[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
],
(x: number) => x*2,
[
[2, 4, 6],
[8, 10, 12],
[14, 16, 18],
],
],
[
[
[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
],
],
(x: number) => x*2,
[
[
[2, 4, 6],
[8, 10, 12],
[14, 16, 18],
],
],
],
[
[
[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
],
[
[10, 20, 30],
[40, 50, 60],
[70, 80, 90],
],
],
(x: number) => x*2,
[
[
[2, 4, 6],
[8, 10, 12],
[14, 16, 18],
],
[
[20, 40, 60],
[80, 100, 120],
[140, 160, 180],
],
],
],
[
[
[
[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
],
[
[10, 20, 30],
[40, 50, 60],
[70, 80, 90],
],
],
],
(x: number) => x*2,
[
[
[
[2, 4, 6],
[8, 10, 12],
[14, 16, 18],
],
[
[20, 40, 60],
[80, 100, 120],
[140, 160, 180],
],
],
],
],
];
}

describe.each([
...dataProviderForTensorBiaryOperation(),
] as Array<[Tensor<number>, Tensor<number>, (lhs: number, rhs: number) => number, Tensor<number>]>)(
'Tensor Binary Operation Test',
(lhs, rhs, operation, expected) => {
it('', () => {
const result = tensorBinaryOperation(lhs, rhs, operation);
expect(result).toEqual(expected);
});
},
);

function dataProviderForTensorBiaryOperation(): Array<[Tensor<number>, Tensor<number>, (lhs: number, rhs: number) => number, Tensor<number>]> {
return [
[
[],
[],
(lhs: number, rhs: number) => lhs + rhs,
[],
],
[
[1],
[2],
(lhs: number, rhs: number) => lhs + rhs,
[3],
],
[
[1, 2, 3],
[10, 20, 30],
(lhs: number, rhs: number) => lhs + rhs,
[11, 22, 33],
],
[
[
[1, 2, 3],
[4, 5, 6],
],
[
[10, 20, 30],
[40, 50, 60],
],
(lhs: number, rhs: number) => lhs + rhs,
[
[11, 22, 33],
[44, 55, 66],
],
],
[
[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
],
[
[10, 20, 30],
[40, 50, 60],
[70, 80, 90],
],
(lhs: number, rhs: number) => lhs + rhs,
[
[11, 22, 33],
[44, 55, 66],
[77, 88, 99],
],
],
[
[
[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
],
],
[
[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
],
],
(lhs: number, rhs: number) => lhs + rhs,
[
[
[2, 4, 6],
[8, 10, 12],
[14, 16, 18],
],
],
],
[
[
[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
],
[
[10, 20, 30],
[40, 50, 60],
[70, 80, 90],
],
],
[
[
[10, 20, 30],
[40, 50, 60],
[70, 80, 90],
],
[
[100, 200, 300],
[400, 500, 600],
[700, 800, 900],
],
],
(lhs: number, rhs: number) => lhs + rhs,
[
[
[11, 22, 33],
[44, 55, 66],
[77, 88, 99],
],
[
[110, 220, 330],
[440, 550, 660],
[770, 880, 990],
],
],
],
[
[
[
[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
],
[
[10, 20, 30],
[40, 50, 60],
[70, 80, 90],
],
],
],
[
[
[
[10, 20, 30],
[40, 50, 60],
[70, 80, 90],
],
[
[100, 200, 300],
[400, 500, 600],
[700, 800, 900],
],
],
],
(lhs: number, rhs: number) => lhs + rhs,
[
[
[
[11, 22, 33],
[44, 55, 66],
[77, 88, 99],
],
[
[110, 220, 330],
[440, 550, 660],
[770, 880, 990],
],
],
],
],
];
}

0 comments on commit 78c0777

Please sign in to comment.