Skip to content

Commit

Permalink
revise sharedmemory attribute to take reguluar GB/MB specification
Browse files Browse the repository at this point in the history
  • Loading branch information
geertvandeweyer committed Nov 4, 2024
1 parent bcd4ef9 commit 502b89d
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import java.security.MessageDigest
import org.apache.commons.lang3.builder.{ToStringBuilder, ToStringStyle}
import org.slf4j.{Logger, LoggerFactory}
import wdl4s.parser.MemoryUnit
import wom.format.MemorySize

/**
* Responsible for the creation of the job definition.
Expand Down Expand Up @@ -164,8 +165,8 @@ trait AwsBatchJobDefinitionBuilder {
efsMakeMD5: Boolean,
tagResources: Boolean,
logGroupName: String,
sharedMemorySize: Int): String = {
s"$imageName:$packedCommand:${volumes.map(_.toString).mkString(",")}:${mountPoints.map(_.toString).mkString(",")}:${env.map(_.toString).mkString(",")}:${ulimits.map(_.toString).mkString(",")}:${efsDelocalize.toString}:${efsMakeMD5.toString}:${tagResources.toString}:$logGroupName"
sharedMemorySize: MemorySize): String = {
s"$imageName:$packedCommand:${volumes.map(_.toString).mkString(",")}:${mountPoints.map(_.toString).mkString(",")}:${env.map(_.toString).mkString(",")}:${ulimits.map(_.toString).mkString(",")}:${efsDelocalize.toString}:${efsMakeMD5.toString}:${tagResources.toString}:$logGroupName:${sharedMemorySize.to(MemoryUnit.MB).amount.toInt}"
}

val environment = List.empty[KeyValuePair]
Expand Down Expand Up @@ -201,9 +202,8 @@ trait AwsBatchJobDefinitionBuilder {
efsMakeMD5,
tagResources,
logGroupName,
context.runtimeAttributes.sharedMemorySize.value
context.runtimeAttributes.sharedMemorySize
)

// To reuse job definition for gpu and gpu-runs, we will create a job definition that does not gpu requirements
// since aws batch does not allow you to set gpu as 0 when you dont need it. you will always need cpu and memory
(ContainerProperties.builder()
Expand All @@ -219,8 +219,8 @@ trait AwsBatchJobDefinitionBuilder {
.environment(environment.asJava)
.ulimits(ulimits.asJava)
.linuxParameters(
LinuxParameters.builder().sharedMemorySize(context.runtimeAttributes.sharedMemorySize.##).build()
),
LinuxParameters.builder().sharedMemorySize(context.runtimeAttributes.sharedMemorySize.to(MemoryUnit.MB).amount.toInt).build() // Convert MemorySize to MB
),
containerPropsName)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ import eu.timepit.refined.api.Refined
import eu.timepit.refined.numeric.Positive
import wom.RuntimeAttributesKeys
import wom.format.MemorySize
import wdl4s.parser.MemoryUnit
import wom.types._
import wom.values._
import com.typesafe.config.{ConfigException, ConfigValueFactory}

import scala.util.matching.Regex
import org.slf4j.{Logger, LoggerFactory}
import wom.RuntimeAttributesKeys.{GpuKey, sharedMemoryKey}
import wom.RuntimeAttributesKeys.{GpuKey } // , sharedMemoryKey}

import scala.util.{Failure, Success, Try}
import scala.jdk.CollectionConverters._
Expand Down Expand Up @@ -94,7 +95,7 @@ case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive,
ulimits: Vector[Map[String, String]],
efsDelocalize: Boolean,
efsMakeMD5 : Boolean,
sharedMemorySize: Int Refined Positive,
sharedMemorySize: MemorySize,
logGroupName: String,
additionalTags: Map[String, String],
fileSystem: String= "s3",
Expand All @@ -110,7 +111,7 @@ object AwsBatchRuntimeAttributes {

val awsBatchEvaluateOnExitKey = "awsBatchEvaluateOnExit"

val defaultSharedMemorySize = WomInteger(64)
val defaultSharedMemorySize = MemorySize(64, MemoryUnit.MB)

private val awsBatchEvaluateOnExitDefault = WomArray(WomArrayType(WomMapType(WomStringType,WomStringType)), Vector(WomMap(Map.empty[WomValue, WomValue])))

Expand Down Expand Up @@ -179,9 +180,10 @@ object AwsBatchRuntimeAttributes {
noAddressValidationInstance
.withDefault(noAddressValidationInstance.configDefaultWomValue(runtimeConfig) getOrElse NoAddressDefaultValue)

private def sharedMemorySizeValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Refined[Int, Positive]] = {
SharedMemorySizeValidation(sharedMemoryKey).withDefault(
SharedMemorySizeValidation(sharedMemoryKey).configDefaultWomValue(runtimeConfig).getOrElse(defaultSharedMemorySize)
private def sharedMemorySizeValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[MemorySize] = {
MemoryValidation.withDefaultMemory(
RuntimeAttributesKeys.sharedMemoryKey,
MemoryValidation.configDefaultString(RuntimeAttributesKeys.sharedMemoryKey, runtimeConfig) getOrElse defaultSharedMemorySize.toString
)
}

Expand Down Expand Up @@ -337,7 +339,7 @@ object AwsBatchRuntimeAttributes {
val efsDelocalize: Boolean = RuntimeAttributesValidation.extract(awsBatchefsDelocalizeValidation(runtimeAttrsConfig),validatedRuntimeAttributes)
val efsMakeMD5: Boolean = RuntimeAttributesValidation.extract(awsBatchefsMakeMD5Validation(runtimeAttrsConfig),validatedRuntimeAttributes)
val tagResources: Boolean = RuntimeAttributesValidation.extract(awsBatchtagResourcesValidation(runtimeAttrsConfig),validatedRuntimeAttributes)
val sharedMemorySize: Int Refined Positive = RuntimeAttributesValidation.extract(sharedMemorySizeValidation(runtimeAttrsConfig), validatedRuntimeAttributes)
val sharedMemorySize: MemorySize = RuntimeAttributesValidation.extract(sharedMemorySizeValidation(runtimeAttrsConfig), validatedRuntimeAttributes)

new AwsBatchRuntimeAttributes(
cpu,
Expand Down Expand Up @@ -708,12 +710,6 @@ class AwsBatchtagResourcesValidation(key: String) extends BooleanRuntimeAttribut
override protected def missingValueMessage: String = s"Expecting $key runtime attribute to be a Boolean"
}

object SharedMemorySizeValidation {
def apply(key: String): SharedMemorySizeValidation = new SharedMemorySizeValidation(key)
}

class SharedMemorySizeValidation(key: String) extends PositiveIntRuntimeAttributesValidation(key)

object UlimitsValidation
extends RuntimeAttributesValidation[Vector[Map[String, String]]] {
override def key: String = AwsBatchRuntimeAttributes.UlimitsKey
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
val s3Outputs: Set[AwsBatchFileOutput] = Set(AwsBatchFileOutput("baa", "s3://bucket/somewhere/baa", DefaultPathBuilder.get("baa"), AwsBatchWorkingDisk()))

val cpu: Int Refined Positive = 2
val sharedMemorySize: Int Refined Positive = 64
val sharedMemorySize: MemorySize = "64 MB"

val runtimeAttributes: AwsBatchRuntimeAttributes = new AwsBatchRuntimeAttributes(
cpu = cpu,
Expand Down

0 comments on commit 502b89d

Please sign in to comment.