Skip to content

Commit

Permalink
[system-a] irisParameters
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyuheng committed May 28, 2024
1 parent a5e7ebf commit 18f2fa6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 17 deletions.
22 changes: 22 additions & 0 deletions src/system-a/block/denseBlock.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,28 @@
import { relu } from "../neurons/relu.js"
import type { Tensor } from "../tensor/Tensor.js"
import { randomTensor } from "../tensor/randomTensor.js"
import type { Shape } from "../tensor/shape.js"
import { zeroTensor } from "../tensor/zeroTensor.js"
import { Block, type BlockFn } from "./Block.js"

export function denseBlock(inputSize: number, layerWidth: number): Block {
return Block(relu as BlockFn, [[layerWidth, inputSize], [layerWidth]])
}

export function denseInitParameters(shapes: Array<Shape>): Array<Tensor> {
return shapes.map(denseInitParameter)
}

export function denseInitParameter(shape: Shape): Tensor {
if (shape.length === 1) {
return zeroTensor(shape)
}

if (shape.length === 2) {
const mean = 0
const deviation = 2 / shape[1]
return randomTensor(mean, deviation, shape)
}

throw new Error(`[denseInitParameter] Wrong shape: ${shape}`)
}
38 changes: 21 additions & 17 deletions src/system-a/models/irisModel.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
import { blockStack, denseBlock } from "../block/index.js"
import { blockStack, denseBlock, denseInitParameters } from "../block/index.js"
import { gradientDescentNaked } from "../gradient-descent/gradientDescentNaked.js"
import { l2Loss } from "../loss.js"
import type { Tensor } from "../tensor/Tensor.js"
import { randomTensor } from "../tensor/randomTensor.js"
import type { Shape } from "../tensor/shape.js"
import { zeroTensor } from "../tensor/zeroTensor.js"
import { samplingObjective } from "../tensor/samplingObjective.js"
import { irisTrainXs, irisTrainYs } from "./irisDataset.js"

export const irisNetwork = blockStack([denseBlock(4, 6), denseBlock(6, 3)])

function initParameters(shapes: Array<Shape>): Array<Tensor> {
return shapes.map(initShape)
}
export function irisParameters(): Array<Tensor> {
const objective = samplingObjective(
l2Loss(irisNetwork.fn),
irisTrainXs,
irisTrainYs,
{
batchSize: 8,
},
)

function initShape(shape: Shape): Tensor {
if (shape.length === 1) {
return zeroTensor(shape)
}
const gradientDescentFn = gradientDescentNaked({
learningRate: 0.0002,
})

if (shape.length === 2) {
const mean = 0
const deviation = 2 / shape[1]
return randomTensor(mean, deviation, shape)
}
const initParameters = denseInitParameters(irisNetwork.shapes)

throw new Error(`[initShape] Wrong shape: ${shape}`)
return gradientDescentFn(objective, initParameters, {
revs: 2000,
})
}

0 comments on commit 18f2fa6

Please sign in to comment.