Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pashashiz committed Nov 12, 2023
1 parent 4bc04d2 commit b24236b
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/main/scala/scanet/core/Shape.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] {
def maxDims(other: Shape): Shape = {
val maxRank = rank max other.rank
val left = alignLeft(maxRank, 1)
val right = other.alignLeft(maxRank, 2)
val right = other.alignLeft(maxRank, 1)
val dimsResult = left.dims.zip(right.dims)
.map { case (l, r) => l max r }
Shape(dimsResult)
Expand Down
4 changes: 0 additions & 4 deletions src/main/scala/scanet/math/alg/AllKernels.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,6 @@ case class Negate[A: Numeric](expr: Expr[A]) extends Expr[A] {
}

case class Multiply[A: Numeric] private (left: Expr[A], right: Expr[A]) extends Expr[A] {
if (!left.broadcastableAny(right)) {
println(s"$left * $right")
println(s"!!!")
}
require(
left.broadcastableAny(right),
s"cannot multiply tensors with shapes ${left.shape} * ${right.shape}")
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/scanet/models/layer/Bias.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import scanet.math.syntax.zeros
import scanet.models.Aggregation.Avg
import scanet.models.Initializer.Zeros
import scanet.models.Regularization.Zero
import scanet.models.{Initializer, Model, ParamDef, Regularization}
import scanet.models.{Initializer, ParamDef, Regularization}
import scanet.syntax._

/** A layer which sums up a bias vector (weights) with the input.
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/scanet/models/CNNSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import scanet.core.Shape
import scanet.estimators.accuracy
import scanet.models.Activation._
import scanet.models.Loss._
import scanet.models.layer.{Activate, BatchNorm, Conv2D, Dense, Flatten, Pool2D}
import scanet.models.layer.{Activate, Conv2D, Dense, Flatten, Pool2D}
import scanet.optimizers.Adam
import scanet.optimizers.Effect.{RecordAccuracy, RecordLoss}
import scanet.optimizers.syntax._
Expand Down

0 comments on commit b24236b

Please sign in to comment.