Skip to content
This repository has been archived by the owner on Sep 26, 2020. It is now read-only.

Commit

Permalink
Add some missing layers commonly used for image processing (#99)
Browse files Browse the repository at this point in the history
* Add AveragePooling2D

* Add GlobalMaxPooling2D

* Add SpatialDropout2D

* Add UpSampling2D

* Remove log4j.properties from kotlintest artifact

* Add GlobalAveragePooling2D and add Layer.ModelLayer for TF 1.15.0 support

* Use nearest interpolation if its null (for TF < v1.15)
  • Loading branch information
Octogonapus authored Nov 6, 2019
1 parent 1724f49 commit e2dda6d
Show file tree
Hide file tree
Showing 17 changed files with 401 additions and 4 deletions.
Binary file modified libraries/kotlintest-core-jvm-4.0.2631-SNAPSHOT.jar
Binary file not shown.
3 changes: 3 additions & 0 deletions test-util/test-util.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ dependencies {
}
})

// Needed by kotlintest
api(group = "com.github.wumpz", name = "diffutils", version = "2.2")

api(
group = "org.junit.jupiter",
name = "junit-jupiter",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ class DefaultLayerToCode : LayerToCode, KoinComponent {
)
}

is Layer.AveragePooling2D -> makeLayerCode(
"tf.keras.layers.AvgPool2D",
listOf(),
listOf(
"pool_size" to layer.poolSize,
"strides" to layer.strides,
"padding" to layer.padding.value,
"data_format" to layer.dataFormat?.value,
"name" to layer.name
)
).right()

is Layer.Dense -> makeLayerCode(
"tf.keras.layers.Dense",
listOf(),
Expand Down Expand Up @@ -106,6 +118,24 @@ class DefaultLayerToCode : LayerToCode, KoinComponent {
)
).right()

is Layer.GlobalAveragePooling2D -> makeLayerCode(
"tf.keras.layers.GlobalAveragePooling2D",
listOf(),
listOf(
"data_format" to layer.dataFormat?.value,
"name" to layer.name
)
).right()

is Layer.GlobalMaxPooling2D -> makeLayerCode(
"tf.keras.layers.GlobalMaxPooling2D",
listOf(),
listOf(
"data_format" to layer.dataFormat?.value,
"name" to layer.name
)
).right()

is Layer.MaxPooling2D -> makeLayerCode(
"tf.keras.layers.MaxPooling2D",
listOf(),
Expand All @@ -118,6 +148,26 @@ class DefaultLayerToCode : LayerToCode, KoinComponent {
)
).right()

is Layer.SpatialDropout2D -> makeLayerCode(
"tf.keras.layers.SpatialDropout2D",
listOf(layer.rate.toString()),
listOf(
"data_format" to layer.dataFormat?.value,
"name" to layer.name
)
).right()

is Layer.UpSampling2D -> makeLayerCode(
"tf.keras.layers.UpSampling2D",
listOf(),
listOf(
"size" to layer.size,
"data_format" to layer.dataFormat?.value,
"interpolation" to layer.interpolation.value,
"name" to layer.name
)
).right()

// TODO: Remove this
else -> "Cannot construct an unknown layer: $layer".left()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ import edu.wpi.axon.tfdata.layer.Activation
import edu.wpi.axon.tfdata.layer.Constraint
import edu.wpi.axon.tfdata.layer.DataFormat
import edu.wpi.axon.tfdata.layer.Initializer
import edu.wpi.axon.tfdata.layer.Interpolation
import edu.wpi.axon.tfdata.layer.Layer
import edu.wpi.axon.tfdata.layer.PoolingPadding
import edu.wpi.axon.tfdata.layer.Regularizer
import edu.wpi.axon.tfdata.layer.Layer
import io.kotlintest.shouldBe
import io.mockk.every
import io.mockk.mockk
Expand Down Expand Up @@ -183,6 +184,57 @@ internal class DefaultLayerToCodeTest : KoinTestFixture() {
),
("""tf.keras.layers.Flatten(data_format="channels_first", name="name")""").right(),
null
),
Arguments.of(
Layer.AveragePooling2D(
"name",
None,
Right(Tuple2(2, 2)),
Left(3),
PoolingPadding.Valid,
DataFormat.ChannelsLast
),
Right("""tf.keras.layers.AvgPool2D(pool_size=(2, 2), strides=3, padding="valid", data_format="channels_last", name="name")"""),
null
),
Arguments.of(
Layer.GlobalMaxPooling2D(
"name",
None,
DataFormat.ChannelsFirst
),
Right("""tf.keras.layers.GlobalMaxPooling2D(data_format="channels_first", name="name")"""),
null
),
Arguments.of(
Layer.SpatialDropout2D(
"name",
None,
0.2,
null
),
Right("""tf.keras.layers.SpatialDropout2D(0.2, data_format=None, name="name")"""),
null
),
Arguments.of(
Layer.UpSampling2D(
"name",
None,
Right(Tuple2(2, 2)),
null,
Interpolation.Nearest
),
Right("""tf.keras.layers.UpSampling2D(size=(2, 2), data_format=None, interpolation="nearest", name="name")"""),
null
),
Arguments.of(
Layer.GlobalAveragePooling2D(
"name",
None,
DataFormat.ChannelsLast
),
Right("""tf.keras.layers.GlobalAveragePooling2D(data_format="channels_last", name="name")"""),
null
)
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package edu.wpi.axon.tfdata.layer

/**
* Values for the `interpolation` parameter for sampling-type layers.
*/
enum class Interpolation(val value: String) {
Nearest("nearest"), Bilinear("bilinear")
}
71 changes: 71 additions & 0 deletions tf-data/src/main/kotlin/edu/wpi/axon/tfdata/layer/Layer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ sealed class Layer {
override val inputs: Option<Set<String>>
) : Layer()

/**
* A layer that contains an entire model inside it.
*
* @param model The model that acts as this layer.
*/
data class ModelLayer(
override val name: String,
override val inputs: Option<Set<String>>,
val model: Model
) : Layer()

/**
* A layer that accepts input data and has no parameters.
*
Expand Down Expand Up @@ -141,6 +152,18 @@ sealed class Layer {
val virtualBatchSize: Int? = null
) : Layer()

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/AveragePooling2D
*/
data class AveragePooling2D(
override val name: String,
override val inputs: Option<Set<String>>,
val poolSize: Either<Int, Tuple2<Int, Int>> = Right(Tuple2(2, 2)),
val strides: Either<Int, Tuple2<Int, Int>>? = null,
val padding: PoolingPadding = PoolingPadding.Valid,
val dataFormat: DataFormat? = null
) : Layer()

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/Conv2D
*/
Expand Down Expand Up @@ -197,6 +220,24 @@ sealed class Layer {
val dataFormat: DataFormat? = null
) : Layer()

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/GlobalAveragePooling2D
*/
data class GlobalAveragePooling2D(
override val name: String,
override val inputs: Option<Set<String>>,
val dataFormat: DataFormat?
) : Layer()

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/GlobalMaxPool2D
*/
data class GlobalMaxPooling2D(
override val name: String,
override val inputs: Option<Set<String>>,
val dataFormat: DataFormat? = null
) : Layer()

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/MaxPool2D
*/
Expand All @@ -208,4 +249,34 @@ sealed class Layer {
val padding: PoolingPadding = PoolingPadding.Valid,
val dataFormat: DataFormat? = null
) : Layer()

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/SpatialDropout2D
*/
data class SpatialDropout2D(
override val name: String,
override val inputs: Option<Set<String>>,
val rate: Double,
val dataFormat: DataFormat? = null
) : Layer() {

init {
require(rate in 0.0..1.0) {
"rate ($rate) was outside the allowed range of [0, 1]."
}
}
}

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/UpSampling2D
*
* Bug: TF does not export a value for [interpolation].
*/
data class UpSampling2D(
override val name: String,
override val inputs: Option<Set<String>>,
val size: Either<Int, Tuple2<Int, Int>> = Right(Tuple2(2, 2)),
val dataFormat: DataFormat? = null,
val interpolation: Interpolation = Interpolation.Nearest
) : Layer()
}
12 changes: 12 additions & 0 deletions tf-data/src/test/kotlin/edu/wpi/axon/tfdata/layer/LayerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ internal class LayerTest {

@Test
fun `dropout with invalid rate`() {
shouldThrow<IllegalArgumentException> { Layer.Dropout("", None, -0.1) }
shouldThrow<IllegalArgumentException> { Layer.Dropout("", None, 1.2) }
}

@Test
fun `spatialdropout2d with valid rate`() {
shouldNotThrow<IllegalArgumentException> { Layer.SpatialDropout2D("", None, 0.5) }
}

@Test
fun `spatialdropout2d with invalid rate`() {
shouldThrow<IllegalArgumentException> { Layer.SpatialDropout2D("", None, -0.1) }
shouldThrow<IllegalArgumentException> { Layer.SpatialDropout2D("", None, 1.2) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import edu.wpi.axon.tfdata.layer.Activation
import edu.wpi.axon.tfdata.layer.Constraint
import edu.wpi.axon.tfdata.layer.DataFormat
import edu.wpi.axon.tfdata.layer.Initializer
import edu.wpi.axon.tfdata.layer.Interpolation
import edu.wpi.axon.tfdata.layer.Layer
import edu.wpi.axon.tfdata.layer.PoolingPadding
import edu.wpi.axon.tfdata.layer.Regularizer
Expand Down Expand Up @@ -129,6 +130,8 @@ class LoadLayersFromHDF5(
val json = data["config"] as JsonObject
val name = json["name"] as String
return when (className) {
"Sequential", "Model" -> Layer.ModelLayer(name, data.inboundNodes(), parseModel(data))

"InputLayer" -> Layer.InputLayer(
name,
(json["batch_input_shape"] as JsonArray<Int?>).toList().let {
Expand Down Expand Up @@ -166,8 +169,16 @@ class LoadLayersFromHDF5(
json["virtual_batch_size"] as Int?
)

"Conv2D"
-> Layer.Conv2D(
"AvgPool2D", "AveragePooling2D" -> Layer.AveragePooling2D(
name,
data.inboundNodes(),
json["pool_size"].tuple2OrInt(),
json["strides"].tuple2OrIntOrNull(),
json["padding"].poolingPadding(),
json["data_format"].dataFormatOrNull()
)

"Conv2D" -> Layer.Conv2D(
name,
data.inboundNodes(),
json["filters"] as Int,
Expand Down Expand Up @@ -208,6 +219,18 @@ class LoadLayersFromHDF5(
json["data_format"].dataFormatOrNull()
)

"GlobalAveragePooling2D", "GlobalAvgPool2D" -> Layer.GlobalAveragePooling2D(
name,
data.inboundNodes(),
json["data_format"].dataFormatOrNull()
)

"GlobalMaxPooling2D", "GlobalMaxPool2D" -> Layer.GlobalMaxPooling2D(
name,
data.inboundNodes(),
json["data_format"].dataFormatOrNull()
)

"MaxPool2D", "MaxPooling2D" -> Layer.MaxPooling2D(
name,
data.inboundNodes(),
Expand All @@ -217,6 +240,21 @@ class LoadLayersFromHDF5(
json["data_format"].dataFormatOrNull()
)

"SpatialDropout2D" -> Layer.SpatialDropout2D(
name,
data.inboundNodes(),
json["rate"].double(),
json["data_format"].dataFormatOrNull()
)

"UpSampling2D" -> Layer.UpSampling2D(
name,
data.inboundNodes(),
json["size"].tuple2OrInt(),
json["data_format"].dataFormatOrNull(),
json["interpolation"].interpolation()
)

else -> Layer.UnknownLayer(
name,
data.inboundNodes()
Expand Down Expand Up @@ -363,6 +401,13 @@ private fun Any?.dataFormatOrNull(): DataFormat? = when (this as? String) {
else -> throw IllegalArgumentException("Not convertible: $this")
}

private fun Any?.interpolation(): Interpolation = when (this as? String) {
// Null in versions < v1.15.0 (TF bug). Use nearest as the default
null, "nearest" -> Interpolation.Nearest
"bilinear" -> Interpolation.Bilinear
else -> throw IllegalArgumentException("Not convertible: $this")
}

private fun Any?.tuple2OrInt(): Either<Int, Tuple2<Int, Int>> = when {
this is Int -> Left(this)

Expand Down
Loading

0 comments on commit e2dda6d

Please sign in to comment.