Commit
Added batch norm
pashashiz committed Nov 12, 2023
1 parent 7c1ea4a commit 4bc04d2
Showing 27 changed files with 483 additions and 105 deletions.
6 changes: 6 additions & 0 deletions src/main/scala/scanet/core/AllKernels.scala
@@ -85,6 +85,12 @@ case class DependsOn[A: TensorType](expr: Expr[A], dep: Expr[_]) extends Expr[A]
  override def inputs: Seq[Expr[_]] = Seq(expr)
  override def controls: Seq[Expr[_]] = Seq(dep)
  override def compiler: Compiler[A] = DefaultCompiler[A]()
  override def localGrad: Grad[A] = new Grad[A] {
    override def calc[R: Floating](
        current: Expr[A],
        parentGrad: Expr[R]): Seq[Expr[R]] =
      List(parentGrad)
  }
}

case class Switch[A: TensorType](cond: Expr[Boolean], output: Expr[A]) extends Expr[(A, A)] {
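The new `localGrad` makes `DependsOn` transparent to differentiation: the node evaluates to `expr` and only pins `dep` as a control dependency, so the derivative is the identity and the parent gradient flows through unchanged. A minimal sketch of the idea, not part of this commit (the `.const` syntax is borrowed from usage elsewhere in this diff; the bindings are illustrative):

```scala
import scanet.core.DependsOn
import scanet.math.syntax._

val x = 2.0f.const
val sideEffect = 1.0f.const
// y has the same value as x, but evaluating it also forces
// sideEffect to be computed first
val y = DependsOn(x, sideEffect)
// since y == x value-wise, d(y)/d(x) = 1, and the new localGrad
// simply forwards the parent gradient: calc(y, g) == List(g)
```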
62 changes: 49 additions & 13 deletions src/main/scala/scanet/core/Shape.scala
@@ -24,11 +24,19 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] {
    indexes.reverse
  }

  def apply(dim: Int): Int = dims(dim)
  def get(dim: Int): Int = dims(dim)
  def get(dim: Int): Int = if (dim == -1) last else dims(dim)

  def apply(dim: Int): Int = get(dim)

  def rank: Int = dims.size

  def axis: List[Int] = dims.indices.toList

  def axisExcept(other: Int*): List[Int] = {
    val indexedAxis = indexAxis(other)
    (dims.indices.toSet -- indexedAxis.toSet).toList.sorted
  }

  def isScalar: Boolean = rank == 0

  def isInBound(projection: Projection): Boolean = {
@@ -52,7 +60,10 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] {
  def drop(n: Int): Shape = Shape(dims.drop(n))
  def dropRight(n: Int): Shape = Shape(dims.dropRight(n))

  def last: Int = dims.last
  def last: Int = {
    require(!isScalar, "cannot get last dimension for scalar")
    dims.last
  }

  def prepend(dim: Int): Shape = Shape(dim +: dims: _*)
  def +:(dim: Int): Shape = prepend(dim)
@@ -159,47 +170,72 @@ }
}
}

  def permute(indexes: Int*): Shape = {
  def maxDims(other: Shape): Shape = {
    val maxRank = rank max other.rank
    val left = alignLeft(maxRank, 1)
    val right = other.alignLeft(maxRank, 2)
    val dimsResult = left.dims.zip(right.dims)
      .map { case (l, r) => l max r }
    Shape(dimsResult)
  }

  def permute(axis: Int*): Shape = {
    val indexedAxis = indexAxis(axis)
    require(
      rank == indexes.size,
      rank == indexedAxis.size,
      "the number of permutation indexes " +
        s"should be equal to rank $rank, but was (${indexes.mkString(", ")})")
    Shape(indexes.foldLeft(List[Int]())((permDims, index) => dims(index) :: permDims).reverse)
        s"should be equal to rank $rank, but was (${axis.mkString(", ")})")
    Shape(indexedAxis.foldLeft(List[Int]())((permDims, index) => dims(index) :: permDims).reverse)
  }

  def select(axis: Int*): Shape = {
    val indexedAxis = indexAxis(axis)
    require(
      axis.forall(_ < rank),
      indexedAxis.forall(i => i < rank && i >= 0),
      s"the number of selected axis " +
        s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})")
    Shape(axis.map(dims(_)).toList)
    Shape(indexedAxis.map(get).toList)
  }

  def remove(axis: Int*): Shape = {
    val indexedAxis = indexAxis(axis)
    require(
      axis.forall(_ < rank),
      indexedAxis.forall(i => i < rank && i >= 0),
      s"the number of removed axis " +
        s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})")
    val filteredDims = dims.zipWithIndex
      .filter { case (_, i) => !axis.contains(i) }
      .filter {
        case (_, i) =>
          !indexedAxis.contains(i)
      }
      .map { case (dim, _) => dim }
    Shape(filteredDims)
  }

  def updated(axis: Int, value: Int): Shape = updateAll(value)(axis)

  def updateAll(value: Int)(axis: Int*): Shape = {
    val indexedAxis = indexAxis(axis)
    require(
      axis.forall(_ < rank),
      indexedAxis.forall(i => i < rank && i >= 0),
      s"the number of updated axis " +
        s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})")
    val updatedDims = dims.zipWithIndex.map {
      case (dim, i) =>
        if (axis.contains(i)) value else dim
        if (indexedAxis.contains(i)) value else dim
    }
    Shape(updatedDims)
  }

  def updateAllExcept(value: Int)(axis: Int*): Shape = {
    val indexedAxis = indexAxis(axis)
    val axisToUpdate = dims.indices.toSet -- indexedAxis.toSet
    updateAll(value)(axisToUpdate.toList: _*)
  }

  private def indexAxis(axis: Seq[Int]): Seq[Int] =
    axis.map(a => if (a == -1) dims.size - 1 else a)

  def minus(other: Shape): Shape = {
    require(broadcastableAny(other), s"cannot $this - $other")
    if (endsWith(other)) {
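Taken together, the new `Shape` helpers let callers address axes from the end and complement a selection, which is exactly what batch norm needs: reduce over every axis except the feature one. A self-contained sketch of the new behavior, not part of this commit (dimension values are illustrative):

```scala
import scanet.core.Shape

val image = Shape(List(32, 28, 28, 3)) // batch, height, width, channels
image(-1)                    // 3: -1 now resolves to the last axis
image.axisExcept(-1)         // List(0, 1, 2): every axis except channels
image.updateAllExcept(1)(-1) // Shape(1, 1, 1, 3): keep channels, set the rest to 1
```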
3 changes: 2 additions & 1 deletion src/main/scala/scanet/core/core.scala
@@ -6,7 +6,8 @@ package object core {

  object syntax extends CoreSyntax

  def error(message: String): Nothing = throw new RuntimeException(message)
  def error(message: String): Nothing =
    throw new RuntimeException(message)

  def memoize[I1, O](f: I1 => O): I1 => O = {
    val cache = mutable.HashMap[I1, O]()
12 changes: 9 additions & 3 deletions src/main/scala/scanet/math/alg/AllKernels.scala
@@ -98,12 +98,16 @@ case class Negate[A: Numeric](expr: Expr[A]) extends Expr[A] {
}

case class Multiply[A: Numeric] private (left: Expr[A], right: Expr[A]) extends Expr[A] {
  require(
    left.broadcastableAny(right),
    s"cannot multiply tensors with shapes ${left.shape} * ${right.shape}")
  override def name: String = "Mul"
  override def tpe: Option[TensorType[A]] = Some(TensorType[A])
  override def shape: Shape = left.shape max right.shape
  override val shape: Shape = left.shape maxDims right.shape
  override def inputs: Seq[Expr[_]] = Seq(left, right)
  override def compiler: core.Compiler[A] = DefaultCompiler[A]()
  override def localGrad: Grad[A] = new Grad[A] {
@@ -233,8 +237,10 @@ case class Mean[A: Numeric] private (expr: Expr[A], axis: Seq[Int], keepDims: Bo
    override def calc[R: Floating](
        current: Expr[A],
        parentGrad: Expr[R]): Seq[Expr[R]] = {
      // we need to recover reduced axis with 1, cause broadcasting will not always work
      val parentShape = axis.foldLeft(parentGrad.shape)((s, axis) => s.insert(axis, 1))
      val parentShape =
        if (keepDims) parentGrad.shape
        // we need to recover the reduced axes with 1, because broadcasting will not always work
        else axis.foldLeft(parentGrad.shape)((s, axis) => s.insert(axis, 1))
      val size = expr.shape.select(axis: _*).power
      List(kernels.ones[R](expr.shape) * parentGrad.reshape(parentShape) / size.const.cast[R])
    }
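The `keepDims` branch matters because when the reduced axes were dropped, the incoming gradient cannot broadcast against the input shape until those axes are re-inserted with size 1. A small sketch of the shape arithmetic, not part of this commit, using only `Shape` methods that appear in this diff (values are illustrative):

```scala
import scanet.core.Shape

val input = Shape(List(2, 3))           // mean over axis 1, keepDims = false
val parentGrad = input.remove(1)        // Shape(2): the reduced axis is gone
val recovered = parentGrad.insert(1, 1) // Shape(2, 1): broadcastable vs Shape(2, 3)
val size = input.select(1).power        // 3: the divisor for the mean gradient
```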
10 changes: 6 additions & 4 deletions src/main/scala/scanet/models/Math.scala
@@ -4,21 +4,23 @@ import scanet.core.Params.Weights
import scanet.core.{Expr, Floating, Params, Shape}
import scanet.math.syntax._
import scanet.models.Aggregation.Avg
import scanet.models.layer.StatelessLayer
import scanet.models.layer.{Layer, StatelessLayer}

object Math {

  case object `x^2` extends StatelessLayer {
  case class `x^2`(override val trainable: Boolean = true) extends StatelessLayer {

    override def params(input: Shape): Params[ParamDef] =
      Params(Weights -> ParamDef(Shape(), Initializer.Zeros, Some(Avg), trainable = true))

    override def buildStateless_[E: Floating](input: Expr[E], params: Params[Expr[E]]): Expr[E] =
    override def buildStateless[E: Floating](input: Expr[E], params: Params[Expr[E]]): Expr[E] =
      pow(params(Weights), 2)

    override def penalty[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E] =
    override def penalty[E: Floating](params: Params[Expr[E]]): Expr[E] =
      zeros[E](Shape())

    override def outputShape(input: Shape): Shape = input

    override def makeTrainable(trainable: Boolean): Layer = copy(trainable = trainable)
  }
}
28 changes: 22 additions & 6 deletions src/main/scala/scanet/models/Model.scala
@@ -30,7 +30,7 @@ abstract class Model extends Serializable {
   * @param params initialized or calculated model params
   * @return penalty
   */
  def penalty[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E]
  def penalty[E: Floating](params: Params[Expr[E]]): Expr[E]

  def result[E: Floating]: (Expr[E], Params[Expr[E]]) => Expr[E] =
    (input, params) => build(input, params)._1
@@ -40,6 +40,11 @@ abstract class Model extends Serializable {

  def outputShape(input: Shape): Shape

  def trainable: Boolean
  def makeTrainable(trainable: Boolean): Model
  def freeze: Model = makeTrainable(false)
  def unfreeze: Model = makeTrainable(true)

  def withLoss(loss: Loss): LossModel = LossModel(this, loss)

  private def makeGraph[E: Floating](input: Shape): Expr[E] =
@@ -89,15 +94,14 @@ case class LossModel(model: Model, lossF: Loss) extends Serializable {
      output: Expr[E],
      params: Params[Expr[E]]): (Expr[E], Params[Expr[E]]) = {
    val (result, nextParams) = model.build(input, params)
    val loss = lossF.build(result, output) plus model.penalty(input.shape, params)
    val loss = lossF.build(result, output) plus model.penalty(params)
    (loss, nextParams)
  }

  def loss[E: Floating]: (Expr[E], Expr[E], Params[Expr[E]]) => Expr[E] =
    (input, output, params) => buildStateful(input, output, params)._1

  def lossStateful[E: Floating]
      : (Expr[E], Expr[E], Params[Expr[E]]) => (Expr[E], Params[Expr[E]]) =
  def lossStateful[E: Floating]: (Expr[E], Expr[E], Params[Expr[E]]) => (Expr[E], Params[Expr[E]]) =
    (input, output, params) => buildStateful(input, output, params)

  def grad[E: Floating]: (Expr[E], Expr[E], Params[Expr[E]]) => Params[Expr[E]] =
@@ -114,7 +118,13 @@ case class LossModel(model: Model, lossF: Loss) extends Serializable {
    (grad, nextState)
  }

  def trained[E: Floating](params: Params[Tensor[E]]) = new TrainedModel(this, params)
  def trainable: Boolean = model.trainable
  def makeTrainable(trainable: Boolean): LossModel = copy(model = model.makeTrainable(trainable))
  def freeze: LossModel = makeTrainable(false)
  def unfreeze: LossModel = makeTrainable(true)

  def trained[E: Floating](params: Params[Tensor[E]]): TrainedModel[E] =
    TrainedModel(this.freeze, params)

  def displayLoss[E: Floating](input: Shape, dir: String = ""): Unit = {
    val params = model.params(input)
@@ -141,7 +151,7 @@ case class LossModel(model: Model, lossF: Loss) extends Serializable {
  override def toString: String = s"$lossF($model)"
}

class TrainedModel[E: Floating](val lossModel: LossModel, val params: Params[Tensor[E]]) {
case class TrainedModel[E: Floating](lossModel: LossModel, params: Params[Tensor[E]]) {

  def buildResult(input: Expr[E]): Expr[E] =
    buildResultStateful(input)._1
@@ -168,4 +178,10 @@ class TrainedModel[E: Floating](val lossModel: LossModel, val params: Params[Ten
    (input, output) => buildLossStateful(input, output)

  def outputShape(input: Shape): Shape = lossModel.model.outputShape(input)

  def trainable: Boolean = lossModel.trainable
  def makeTrainable(trainable: Boolean): TrainedModel[E] =
    copy(lossModel = lossModel.makeTrainable(trainable))
  def freeze: TrainedModel[E] = makeTrainable(false)
  def unfreeze: TrainedModel[E] = makeTrainable(true)
}
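The freeze/unfreeze pair now flows through all three levels (`Model`, `LossModel`, `TrainedModel`), and `trained` freezes the model before wrapping it, so inference never re-enables weight updates. A minimal sketch using the `x^2` layer from this commit, not part of the diff itself (assuming `StatelessLayer` extends `Model`, as its overrides suggest):

```scala
import scanet.models.Math.`x^2`

val model = `x^2`()          // trainable = true by default
val frozen = model.freeze    // copies the layer with trainable = false
val thawed = frozen.unfreeze // and back again via makeTrainable(true)
```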
