From b24236b9d79ccc572d8c3210c887831bc2b7296c Mon Sep 17 00:00:00 2001 From: Pavlo Pohrebnyi Date: Sun, 12 Nov 2023 13:01:55 +0200 Subject: [PATCH] Fixes --- src/main/scala/scanet/core/Shape.scala | 2 +- src/main/scala/scanet/math/alg/AllKernels.scala | 4 ---- src/main/scala/scanet/models/layer/Bias.scala | 2 +- src/test/scala/scanet/models/CNNSpec.scala | 2 +- 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/main/scala/scanet/core/Shape.scala b/src/main/scala/scanet/core/Shape.scala index 2493933..bc15ef0 100644 --- a/src/main/scala/scanet/core/Shape.scala +++ b/src/main/scala/scanet/core/Shape.scala @@ -173,7 +173,7 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] { def maxDims(other: Shape): Shape = { val maxRank = rank max other.rank val left = alignLeft(maxRank, 1) - val right = other.alignLeft(maxRank, 2) + val right = other.alignLeft(maxRank, 1) val dimsResult = left.dims.zip(right.dims) .map { case (l, r) => l max r } Shape(dimsResult) diff --git a/src/main/scala/scanet/math/alg/AllKernels.scala b/src/main/scala/scanet/math/alg/AllKernels.scala index 28b2c87..0c99fdf 100644 --- a/src/main/scala/scanet/math/alg/AllKernels.scala +++ b/src/main/scala/scanet/math/alg/AllKernels.scala @@ -98,10 +98,6 @@ case class Negate[A: Numeric](expr: Expr[A]) extends Expr[A] { } case class Multiply[A: Numeric] private (left: Expr[A], right: Expr[A]) extends Expr[A] { - if (!left.broadcastableAny(right)) { - println(s"$left * $right") - println(s"!!!") - } require( left.broadcastableAny(right), s"cannot multiply tensors with shapes ${left.shape} * ${right.shape}") diff --git a/src/main/scala/scanet/models/layer/Bias.scala b/src/main/scala/scanet/models/layer/Bias.scala index 79e5180..db065ed 100644 --- a/src/main/scala/scanet/models/layer/Bias.scala +++ b/src/main/scala/scanet/models/layer/Bias.scala @@ -6,7 +6,7 @@ import scanet.math.syntax.zeros import scanet.models.Aggregation.Avg import scanet.models.Initializer.Zeros import scanet.models.Regularization.Zero -import scanet.models.{Initializer, Model, ParamDef, Regularization} +import scanet.models.{Initializer, ParamDef, Regularization} import scanet.syntax._ /** A layer which sums up a bias vector (weights) with the input. diff --git a/src/test/scala/scanet/models/CNNSpec.scala b/src/test/scala/scanet/models/CNNSpec.scala index 16a3670..ca26e31 100644 --- a/src/test/scala/scanet/models/CNNSpec.scala +++ b/src/test/scala/scanet/models/CNNSpec.scala @@ -6,7 +6,7 @@ import scanet.core.Shape import scanet.estimators.accuracy import scanet.models.Activation._ import scanet.models.Loss._ -import scanet.models.layer.{Activate, BatchNorm, Conv2D, Dense, Flatten, Pool2D} +import scanet.models.layer.{Activate, Conv2D, Dense, Flatten, Pool2D} import scanet.optimizers.Adam import scanet.optimizers.Effect.{RecordAccuracy, RecordLoss} import scanet.optimizers.syntax._