Added fused Batch Norm (#99)
* Added fused batch norm
pashashiz authored Nov 20, 2023
1 parent d5bd54e commit 367e49b
Showing 10 changed files with 442 additions and 90 deletions.
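
Before the diff, a note on scope: the fused batch-norm expression itself lands in the files that are not shown below; the two files that are shown add its building blocks: an `axis` -> `axes` rename in `Shape`, plus new `moments` and `rsqrt` kernels. As a rough orientation, here is a minimal sketch of how those blocks compose into the unfused form of batch normalization. It assumes an NHWC layout and a `scanet.syntax._` style entry-point import; neither is taken from this diff.

    import scanet.syntax._ // assumed entry-point import

    // Unfused batch norm over an NHWC tensor: normalize across the batch and
    // spatial axes, leaving the channel axis intact. Illustrative sketch only,
    // not the fused kernel added by this PR.
    def batchNorm[A: Floating](
        x: Expr[A],       // activations, shape (N, H, W, C)
        scale: Expr[A],   // gamma, shape (C)
        offset: Expr[A],  // beta, shape (C)
        epsilon: Expr[A]): Expr[A] = {
      // keepDims = true keeps rank 4 so mean/variance broadcast back over x
      val (mean, variance) = x.moments(axis = Seq(0, 1, 2), keepDims = true)
      val inv = (variance + epsilon).rsqrt * scale
      x * inv + (offset - mean * inv)
    }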
73 changes: 38 additions & 35 deletions src/main/scala/scanet/core/Shape.scala
@@ -30,11 +30,11 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] {
 
   def rank: Int = dims.size
 
-  def axis: List[Int] = dims.indices.toList
+  def axes: List[Int] = dims.indices.toList
 
-  def axisExcept(other: Int*): List[Int] = {
-    val indexedAxis = indexAxis(other)
-    (dims.indices.toSet -- indexedAxis.toSet).toList.sorted
+  def axesExcept(other: Int*): List[Int] = {
+    val indexedAxes = indexAxes(other)
+    (dims.indices.toSet -- indexedAxes.toSet).toList.sorted
   }
 
   def isScalar: Boolean = rank == 0
@@ -158,8 +158,8 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] {
   def broadcastableAny(other: Shape): Boolean =
     broadcastableBy(other) || other.broadcastableBy(this)
 
-  def broadcastableAxis(other: Shape): Seq[Int] = {
-    require(broadcastableAny(other), s"cannot find broadcastable axis for $this and $other")
+  def broadcastableAxes(other: Shape): Seq[Int] = {
+    require(broadcastableAny(other), s"cannot find broadcastable axes for $this and $other")
     if (rank < other.rank) {
       Seq()
     } else {
@@ -179,62 +179,65 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] {
     Shape(dimsResult)
   }
 
-  def permute(axis: Int*): Shape = {
-    val indexedAxis = indexAxis(axis)
+  def permute(axes: Int*): Shape = {
+    val indexedAxes = indexAxes(axes)
     require(
-      rank == indexedAxis.size,
+      rank == indexedAxes.size,
       "the number of permutation indexes " +
-        s"should be equal to rank $rank, but was (${axis.mkString(", ")})")
-    Shape(indexedAxis.foldLeft(List[Int]())((permDims, index) => dims(index) :: permDims).reverse)
+        s"should be equal to rank $rank, but was (${axes.mkString(", ")})")
+    Shape(indexedAxes.foldLeft(List[Int]())((permDims, index) => dims(index) :: permDims).reverse)
   }
 
-  def select(axis: Int*): Shape = {
-    val indexedAxis = indexAxis(axis)
+  def select(axes: Int*): Shape = {
+    val indexedAxes = indexAxes(axes)
     require(
-      indexedAxis.forall(i => i < rank && i >= 0),
-      s"the number of selected axis " +
-        s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})")
-    Shape(indexedAxis.map(get).toList)
+      indexedAxes.forall(i => i < rank && i >= 0),
+      s"the number of selected axes " +
+        s"should be less or equal to rank $rank, but was (${axes.mkString(", ")})")
+    Shape(indexedAxes.map(get).toList)
   }
 
-  def remove(axis: Int*): Shape = {
-    val indexedAxis = indexAxis(axis)
+  def remove(axes: Int*): Shape = {
+    val indexedAxes = indexAxes(axes)
     require(
-      indexedAxis.forall(i => i < rank && i >= 0),
-      s"the number of removed axis " +
-        s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})")
+      indexedAxes.forall(i => i < rank && i >= 0),
+      s"the number of removed axes " +
+        s"should be less or equal to rank $rank, but was (${axes.mkString(", ")})")
     val filteredDims = dims.zipWithIndex
       .filter {
         case (_, i) =>
-          !indexedAxis.contains(i)
+          !indexedAxes.contains(i)
       }
       .map { case (dim, _) => dim }
     Shape(filteredDims)
   }
 
   def updated(axis: Int, value: Int): Shape = updateAll(value)(axis)
 
-  def updateAll(value: Int)(axis: Int*): Shape = {
-    val indexedAxis = indexAxis(axis)
+  def updateAll(value: Int)(axes: Int*): Shape = {
+    val indexedAxes = indexAxes(axes)
     require(
-      indexedAxis.forall(i => i < rank && i >= 0),
-      s"the number of updated axis " +
-        s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})")
+      indexedAxes.forall(i => i < rank && i >= 0),
+      s"the number of updated axes " +
+        s"should be less or equal to rank $rank, but was (${axes.mkString(", ")})")
     val updatedDims = dims.zipWithIndex.map {
       case (dim, i) =>
-        if (indexedAxis.contains(i)) value else dim
+        if (indexedAxes.contains(i)) value else dim
     }
     Shape(updatedDims)
   }
 
-  def updateAllExcept(value: Int)(axis: Int*): Shape = {
-    val indexedAxis = indexAxis(axis)
-    val axisToUpdate = dims.indices.toSet -- indexedAxis.toSet
-    updateAll(value)(axisToUpdate.toList: _*)
+  def updateAllExcept(value: Int)(axes: Int*): Shape = {
+    val indexedAxes = indexAxes(axes)
+    val axesToUpdate = dims.indices.toSet -- indexedAxes.toSet
+    updateAll(value)(axesToUpdate.toList: _*)
   }
 
-  private def indexAxis(axis: Seq[Int]): Seq[Int] =
-    axis.map(a => if (a == -1) dims.size - 1 else a)
+  def indexAxes(axes: Seq[Int]): Seq[Int] =
+    axes.map(indexAxis)
+
+  def indexAxis(axis: Int): Int =
+    if (axis == -1) dims.size - 1 else axis
 
   def minus(other: Shape): Shape = {
     require(broadcastableAny(other), s"cannot $this - $other")
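
The `Shape` changes are a mechanical rename: the variadic `axis` parameters become `axes`, and the private multi-axis `indexAxis` splits into a public `indexAxes` plus a public single-axis `indexAxis`. A short sketch of the expected behaviour under the new names (illustrative, not taken from the test suite):

    val shape = Shape(List(2, 3, 4))
    shape.axes                   // List(0, 1, 2)
    shape.axesExcept(1)          // List(0, 2)
    shape.indexAxis(-1)          // 2, since -1 resolves to the last axis
    shape.indexAxes(Seq(0, -1))  // List(0, 2)
    shape.updateAllExcept(1)(-1) // Shape(1, 1, 4)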
95 changes: 84 additions & 11 deletions src/main/scala/scanet/math/alg/AllKernels.scala
Expand Up @@ -23,8 +23,8 @@ case class Plus[A: Numeric](left: Expr[A], right: Expr[A]) extends Expr[A] {
current: Expr[A],
parentGrad: Expr[R]): Seq[Expr[R]] = {
val parentShape = parentGrad.shape
val shrinkRightAxis = parentShape.broadcastableAxis(right.shape).toList
val shrinkLeftAxis = parentShape.broadcastableAxis(left.shape).toList
val shrinkRightAxis = parentShape.broadcastableAxes(right.shape).toList
val shrinkLeftAxis = parentShape.broadcastableAxes(left.shape).toList
List(
parentGrad.sum(shrinkLeftAxis).reshape(left.shape),
parentGrad.sum(shrinkRightAxis).reshape(right.shape))
@@ -74,8 +74,8 @@ case class Minus[A: Numeric](left: Expr[A], right: Expr[A]) extends Expr[A] {
         current: Expr[A],
         parentGrad: Expr[R]): Seq[Expr[R]] = {
       val parentShape = parentGrad.shape
-      val shrinkLeftAxis = parentShape.broadcastableAxis(left.shape).toList
-      val shrinkRightAxis = parentShape.broadcastableAxis(right.shape).toList
+      val shrinkLeftAxis = parentShape.broadcastableAxes(left.shape).toList
+      val shrinkRightAxis = parentShape.broadcastableAxes(right.shape).toList
       List(
         parentGrad.sum(shrinkLeftAxis).reshape(left.shape),
         -parentGrad.sum(shrinkRightAxis).reshape(right.shape))
@@ -111,8 +111,8 @@ case class Multiply[A: Numeric] private (left: Expr[A], right: Expr[A]) extends
         current: Expr[A],
         parentGrad: Expr[R]): Seq[Expr[R]] = {
       val parentShape = parentGrad.shape
-      val shrinkRightAxis = parentShape.broadcastableAxis(right.shape).toList
-      val shrinkLeftAxis = parentShape.broadcastableAxis(left.shape).toList
+      val shrinkRightAxis = parentShape.broadcastableAxes(right.shape).toList
+      val shrinkLeftAxis = parentShape.broadcastableAxes(left.shape).toList
       List(
         (right.cast[R] * parentGrad).sum(shrinkLeftAxis).reshape(left.shape),
         (left.cast[R] * parentGrad).sum(shrinkRightAxis).reshape(right.shape))
@@ -137,7 +137,7 @@ case class Pow[A: Numeric](expr: Expr[A], exponent: Expr[Float]) extends Expr[A]
   }
 }
 
-case class Sqrt[A: Numeric](expr: Expr[A]) extends Expr[A] {
+case class Sqrt[A: Numeric](expr: Expr[A]) extends Expr[A] { self =>
   override def name: String = "Sqrt"
   override def tpe: Option[TensorType[A]] = Some(TensorType[A])
   override def shape: Shape = expr.shape
@@ -147,12 +147,46 @@ case class Sqrt[A: Numeric](expr: Expr[A]) extends Expr[A] {
     override def calc[R: Floating](
         current: Expr[A],
         parentGrad: Expr[R]): Seq[Expr[R]] = {
-      val local = (expr.cast[R] ^ -0.5f) * 0.5f.const.cast[R]
-      List(local * parentGrad)
+      // val local = (expr.cast[R] ^ -0.5f) * 0.5f.const.cast[R]
+      // List(local * parentGrad)
+      List(SqrtGrad(self.cast[R], parentGrad))
     }
   }
 }
 
+case class SqrtGrad[A: Numeric](sqrt: Expr[A], parentGrad: Expr[A]) extends Expr[A] {
+  override def name: String = "SqrtGrad"
+  override def tpe: Option[TensorType[A]] = Some(TensorType[A])
+  override def shape: Shape = sqrt.shape
+  override def inputs: Seq[Expr[_]] = Seq(sqrt, parentGrad)
+  override def compiler: Compiler[A] = DefaultCompiler[A]()
+}
+
+case class Rsqrt[A: Numeric](expr: Expr[A]) extends Expr[A] { self =>
+  override def name: String = "Rsqrt"
+  override def tpe: Option[TensorType[A]] = Some(TensorType[A])
+  override def shape: Shape = expr.shape
+  override def inputs: Seq[Expr[_]] = Seq(expr)
+  override def compiler: Compiler[A] = DefaultCompiler[A]()
+  override def localGrad: Grad[A] = new Grad[A] {
+    override def calc[R: Floating](
+        current: Expr[A],
+        parentGrad: Expr[R]): Seq[Expr[R]] = {
+      // val local = (expr.cast[R] ^ -1.5f) * -0.5f.const.cast[R]
+      // List(local * parentGrad)
+      List(RsqrtGrad(self.cast[R], parentGrad))
+    }
+  }
+}
+
+case class RsqrtGrad[A: Numeric](rsqrt: Expr[A], parentGrad: Expr[A]) extends Expr[A] {
+  override def name: String = "RsqrtGrad"
+  override def tpe: Option[TensorType[A]] = Some(TensorType[A])
+  override def shape: Shape = rsqrt.shape
+  override def inputs: Seq[Expr[_]] = Seq(rsqrt, parentGrad)
+  override def compiler: Compiler[A] = DefaultCompiler[A]()
+}
+
 case class Exp[A: Numeric](expr: Expr[A]) extends Expr[A] {
   override def name: String = "Exp"
   override def tpe: Option[TensorType[A]] = Some(TensorType[A])
@@ -182,8 +216,8 @@ case class Div[A: Numeric](left: Expr[A], right: Expr[A]) extends Expr[A] {
         current: Expr[A],
         parentGrad: Expr[R]): Seq[Expr[R]] = {
       val parentShape = parentGrad.shape
-      val shrinkRightAxis = parentShape.broadcastableAxis(right.shape).toList
-      val shrinkLeftAxis = parentShape.broadcastableAxis(left.shape).toList
+      val shrinkRightAxis = parentShape.broadcastableAxes(right.shape).toList
+      val shrinkLeftAxis = parentShape.broadcastableAxes(left.shape).toList
       List(
         (parentGrad / right.cast[R]).sum(shrinkLeftAxis).reshape(left.shape),
         (-left.cast[R] * parentGrad / right.sqr.cast[R])
@@ -452,6 +486,8 @@ trait AllKernels {
 
   def sqrt[A: Numeric](expr: Expr[A]): Expr[A] = Sqrt(expr)
 
+  def rsqrt[A: Numeric](expr: Expr[A]): Expr[A] = Rsqrt(expr)
+
   def sqrtZeroSafe[A: Numeric](out: Expr[A], epsilon: Expr[A]): Expr[A] =
     sqrt(plus(out, epsilon))
 
@@ -466,6 +502,20 @@
       keepDims: Boolean = false): Expr[A] = Mean(expr, axis, keepDims)
   def mean[A: Numeric](expr: Expr[A]): Expr[A] = mean(expr, 0 until expr.rank)
 
+  def moments[A: Numeric](
+      expr: Expr[A],
+      axis: Seq[Int],
+      keepDims: Boolean = false): (Expr[A], Expr[A]) = {
+    val m = mean(expr, axis, keepDims)
+    // try squared_difference, it has an optimized kernel op
+    val v = mean((expr - m).sqr, axis, keepDims)
+    (m, v)
+  }
+
+  def moments[A: Numeric](
+      expr: Expr[A]): (Expr[A], Expr[A]) =
+    moments(expr, 0 until expr.rank)
+
   def max[A: TensorType, C](left: Expr[A], right: C)(implicit c: Convertible[C, Expr[A]]): Expr[A] =
     Max(left, c.convert(right))
 
@@ -597,6 +647,12 @@ object kernels extends AllKernels {
     */
   def sqr: Expr[A] = pow(2.0f)
 
+  /** Computes the reciprocal (inverse) of the square root of x element-wise: `1 / sqrt(x)`
+    *
+    * @return tensor `^ -0.5`
+    */
+  def rsqrt: Expr[A] = f.rsqrt(expr)
+
   /** Returns square root of the given tensor
     *
     * {{{Tensor.vector(1.0f, 4.0f, 9.0f).const.sqrt.eval should be(Tensor.vector(1.0f, 2.0f, 3.0f))}}}
@@ -676,6 +732,23 @@
     */
   def mean: Expr[A] = f.mean(expr)
 
+  /** Computes the frequency-weighted mean and variance across dimensions of a tensor.
+    *
+    * Reduces `(mean, variance)` along the dimensions given in `axis`.
+    * The rank of the tensor is reduced by 1 for each entry in `axis`, unless `keepDims` is true.
+    *
+    * @param axis axes to reduce along
+    * @return tensors `(mean, variance)`
+    */
+  def moments(axis: Seq[Int], keepDims: Boolean = false): (Expr[A], Expr[A]) =
+    f.moments(expr, axis, keepDims)
+
+  /** Computes the frequency-weighted mean and variance across all dimensions of a tensor.
+    *
+    * @return tensors `(mean, variance)`
+    */
+  def moments: (Expr[A], Expr[A]) = f.moments(expr)
+
   /** Shuffle dimensions of `out` according to a permutation.
     *
     * {{{
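
A hedged usage sketch for the new kernels (the import and the `.eval` calls mirror the scaladoc examples above and are assumptions, not part of this diff):

    import scanet.syntax._ // assumed entry-point import

    val x = Tensor.vector(1.0f, 4.0f, 16.0f).const
    x.rsqrt.eval // Tensor(1.0, 0.5, 0.25), i.e. 1 / sqrt(x)

    val (m, v) = Tensor.vector(1.0f, 2.0f, 3.0f, 4.0f).const.moments
    m.eval // Tensor(2.5): the mean over all axes
    v.eval // Tensor(1.25): mean((x - 2.5)^2)

Design note: the `Sqrt` (and new `Rsqrt`) gradients are no longer composed from `pow` and `multiply`; the old expressions survive as comments, and each gradient now emits a dedicated `SqrtGrad`/`RsqrtGrad` node, presumably so that `DefaultCompiler` can map it by name onto TensorFlow's native SqrtGrad/RsqrtGrad kernels.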
(diff for the remaining 8 changed files was not loaded)
