Skip to content

Commit

Permalink
feat: add nvidia MIG Settings
Browse files Browse the repository at this point in the history
  • Loading branch information
piyush-jena committed Nov 6, 2024
1 parent 276b8e8 commit a5f78d1
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ bottlerocket-template-helper = { path = "./bottlerocket-template-helper", versio

# Settings Models
bottlerocket-model-derive = { path = "./bottlerocket-settings-models/model-derive", version = "0.1" }
bottlerocket-modeled-types = { path = "./bottlerocket-settings-models/modeled-types", version = "0.6" }
bottlerocket-modeled-types = { path = "./bottlerocket-settings-models/modeled-types", version = "0.7" }
bottlerocket-scalar = { path = "./bottlerocket-settings-models/scalar", version = "0.1" }
bottlerocket-scalar-derive = { path = "./bottlerocket-settings-models/scalar-derive", version = "0.1" }
bottlerocket-string-impls-for = { path = "./bottlerocket-settings-models/string-impls-for", version = "0.1" }
Expand Down
2 changes: 1 addition & 1 deletion bottlerocket-settings-models/modeled-types/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "bottlerocket-modeled-types"
version = "0.6.0"
version = "0.7.0"
authors = []
license = "Apache-2.0 OR MIT"
edition = "2021"
Expand Down
231 changes: 229 additions & 2 deletions bottlerocket-settings-models/modeled-types/src/kubernetes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1451,14 +1451,222 @@ mod test_hostname_override_source {

// =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^=

/// A MIG profile setting for A100 GPUs, stored as the string it was given as.
///
/// Values built through `TryFrom<&str>` have been checked against
/// `ValidA100Policy`; `Default` yields `7g.40gb`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct MIGA100Profile {
// The accepted input, kept verbatim (it is not rewritten to another spelling).
inner: String,
}

/// Validation-only enumeration of the profile strings accepted for A100 GPUs.
///
/// Each variant is matched by its exact profile name (e.g. `1g.5gb`) or by the
/// numeric alias declared next to it. The names are spelled out with explicit
/// `rename` attributes instead of `rename_all = "lowercase"`: the latter makes
/// serde also accept the lowercased Rust identifier (e.g. `profile1g5gb`),
/// which is not a real profile and must be rejected.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Deserialize)]
enum ValidA100Policy {
    #[serde(rename = "1g.5gb", alias = "7")]
    Profile1g5gb,
    #[serde(rename = "2g.10gb", alias = "3")]
    Profile2g10gb,
    #[serde(rename = "3g.20gb", alias = "2")]
    Profile3g20gb,
    #[serde(rename = "7g.40gb", alias = "1")]
    Profile7g40gb,
}

impl TryFrom<&str> for MIGA100Profile {
type Error = error::Error;

fn try_from(input: &str) -> Result<Self, Self::Error> {
serde_plain::from_str::<ValidA100Policy>(input)
.context(error::InvalidMIGProfileSnafu { input })?;
Ok(MIGA100Profile {
inner: input.to_string(),
})
}
}

impl Default for MIGA100Profile {
fn default() -> Self {
MIGA100Profile {
inner: "7g.40gb".to_string(),
}
}
}

string_impls_for!(MIGA100Profile, "MIGA100Profile");

#[cfg(test)]
mod test_valid_a100_profile {
    use super::MIGA100Profile;
    use std::convert::TryFrom;

    /// Every profile name and numeric alias listed here must be accepted.
    #[test]
    fn valid_a100_profile() {
        let accepted = [
            "1g.5gb", "2g.10gb", "3g.20gb", "7g.40gb", "1", "2", "3", "7",
        ];
        assert!(accepted
            .iter()
            .all(|candidate| MIGA100Profile::try_from(*candidate).is_ok()));
    }

    /// Strings that are not A100 profiles must be rejected.
    #[test]
    fn invalid_a100_profile() {
        let rejected = ["invalid", "1000", "1g.7gb"];
        assert!(rejected
            .iter()
            .all(|candidate| MIGA100Profile::try_from(*candidate).is_err()));
    }
}

/// A MIG profile setting for H100 GPUs, stored as the string it was given as.
///
/// Values built through `TryFrom<&str>` have been checked against
/// `ValidH100Policy`; `Default` yields `7g.80gb`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct MIGH100Profile {
// The accepted input, kept verbatim (it is not rewritten to another spelling).
inner: String,
}

/// Validation-only enumeration of the profile strings accepted for H100 GPUs.
///
/// Each variant is matched by its exact profile name (e.g. `1g.10gb`) or by the
/// numeric alias declared next to it. The names are spelled out with explicit
/// `rename` attributes instead of `rename_all = "lowercase"`: the latter makes
/// serde also accept the lowercased Rust identifier (e.g. `profile1g10gb`),
/// which is not a real profile and must be rejected.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Deserialize)]
enum ValidH100Policy {
    #[serde(rename = "1g.10gb", alias = "7")]
    Profile1g10gb,
    #[serde(rename = "1g.20gb", alias = "4")]
    Profile1g20gb,
    #[serde(rename = "2g.20gb", alias = "3")]
    Profile2g20gb,
    #[serde(rename = "3g.40gb", alias = "2")]
    Profile3g40gb,
    #[serde(rename = "7g.80gb", alias = "1")]
    Profile7g80gb,
}

impl TryFrom<&str> for MIGH100Profile {
type Error = error::Error;

fn try_from(input: &str) -> Result<Self, Self::Error> {
serde_plain::from_str::<ValidH100Policy>(input)
.context(error::InvalidMIGProfileSnafu { input })?;
Ok(MIGH100Profile {
inner: input.to_string(),
})
}
}

impl Default for MIGH100Profile {
fn default() -> Self {
MIGH100Profile {
inner: "7g.80gb".to_string(),
}
}
}

string_impls_for!(MIGH100Profile, "MIGH100Profile");

#[cfg(test)]
mod test_valid_h100_profile {
    use super::MIGH100Profile;
    use std::convert::TryFrom;

    /// Every profile name and numeric alias listed here must be accepted.
    #[test]
    fn valid_h100_profile() {
        let accepted = [
            "1g.10gb", "1g.20gb", "2g.20gb", "3g.40gb", "7g.80gb", "1", "2", "3", "4", "7",
        ];
        assert!(accepted
            .iter()
            .all(|candidate| MIGH100Profile::try_from(*candidate).is_ok()));
    }

    /// Strings that are not H100 profiles must be rejected.
    #[test]
    fn invalid_h100_profile() {
        let rejected = ["invalid", "1000", "1g.7gb"];
        assert!(rejected
            .iter()
            .all(|candidate| MIGH100Profile::try_from(*candidate).is_err()));
    }
}

/// A MIG profile setting for H200 GPUs, stored as the string it was given as.
///
/// Values built through `TryFrom<&str>` have been checked against
/// `ValidH200Policy`; `Default` yields `7g.141gb`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct MIGH200Profile {
// The accepted input, kept verbatim (it is not rewritten to another spelling).
inner: String,
}

/// Validation-only enumeration of the profile strings accepted for H200 GPUs.
///
/// Each variant is matched by its exact profile name (e.g. `1g.18gb`) or by the
/// numeric alias declared next to it. The names are spelled out with explicit
/// `rename` attributes instead of `rename_all = "lowercase"`: the latter makes
/// serde also accept the lowercased Rust identifier as input, which is not a
/// real profile and must be rejected.
///
/// The variant identifiers mirror the H200 profile they stand for; they are
/// only used as deserialization targets and are never matched on by name.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Deserialize)]
enum ValidH200Policy {
    #[serde(rename = "1g.18gb", alias = "7")]
    Profile1g18gb,
    #[serde(rename = "1g.35gb", alias = "4")]
    Profile1g35gb,
    #[serde(rename = "2g.35gb", alias = "3")]
    Profile2g35gb,
    #[serde(rename = "3g.71gb", alias = "2")]
    Profile3g71gb,
    #[serde(rename = "7g.141gb", alias = "1")]
    Profile7g141gb,
}

impl TryFrom<&str> for MIGH200Profile {
type Error = error::Error;

fn try_from(input: &str) -> Result<Self, Self::Error> {
serde_plain::from_str::<ValidH200Policy>(input)
.context(error::InvalidMIGProfileSnafu { input })?;
Ok(MIGH200Profile {
inner: input.to_string(),
})
}
}

impl Default for MIGH200Profile {
fn default() -> Self {
MIGH200Profile {
inner: "7g.141gb".to_string(),
}
}
}

string_impls_for!(MIGH200Profile, "MIGH200Profile");

#[cfg(test)]
mod test_valid_h200_profile {
    use super::MIGH200Profile;
    use std::convert::TryFrom;

    /// Every profile name and numeric alias listed here must be accepted.
    #[test]
    fn valid_h200_profile() {
        let accepted = [
            "1g.18gb", "1g.35gb", "2g.35gb", "3g.71gb", "7g.141gb", "1", "2", "3", "4", "7",
        ];
        assert!(accepted
            .iter()
            .all(|candidate| MIGH200Profile::try_from(*candidate).is_ok()));
    }

    /// Strings that are not H200 profiles must be rejected.
    #[test]
    fn invalid_h200_profile() {
        let rejected = ["invalid", "1000", "1g.7gb"];
        assert!(rejected
            .iter()
            .all(|candidate| MIGH200Profile::try_from(*candidate).is_err()));
    }
}

// =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^=

/// NvidiaDevicePluginSettings contains the settings for the Nvidia device plugin.
///
/// NOTE(review): the `model` attribute presumably wraps every field in `Option`
/// (the tests in this file build values with `Some(..)` / `None`) — confirm
/// against the model-derive crate.
#[model(impl_default = true)]
pub struct NvidiaDevicePluginSettings {
pass_device_specs: bool,
device_id_strategy: NvidiaDeviceIdStrategy,
device_list_strategy: NvidiaDeviceListStrategy,
device_sharing_strategy: NvidiaDeviceSharingStrategy,
// Whether/how a GPU is partitioned; see `NvidiaDevicePartitioningStrategy`.
device_partitioning_strategy: NvidiaDevicePartitioningStrategy,
time_slicing: NvidiaTimeSlicingSettings,
// Per-GPU-model MIG profiles; see `NvidiaMIGSettings`.
mig: NvidiaMIGSettings,
}

#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
Expand Down Expand Up @@ -1489,6 +1697,21 @@ pub struct NvidiaTimeSlicingSettings {
fail_requests_greater_than_one: bool,
}

/// Strategy used to partition an Nvidia GPU.
///
/// Variants are (de)serialized in lowercase: `"none"` (the default) and `"mig"`.
/// `Eq` and `Hash` are derived so this enum offers the same trait set as the
/// other strategy and profile types in this module.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum NvidiaDevicePartitioningStrategy {
    /// No partitioning.
    #[default]
    None,
    /// Partition the GPU using MIG.
    MIG,
}

/// NvidiaMIGSettings holds one MIG profile setting per supported GPU model.
///
/// Each field has its own type, so a profile string is validated against the
/// list of profiles accepted for that specific GPU model.
#[model(impl_default = true)]
pub struct NvidiaMIGSettings {
// Profile used on A100 GPUs.
profile_a100: MIGA100Profile,
// Profile used on H100 GPUs.
profile_h100: MIGH100Profile,
// Profile used on H200 GPUs.
profile_h200: MIGH200Profile,
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -1505,7 +1728,9 @@ mod tests {
device_id_strategy: Some(NvidiaDeviceIdStrategy::Uuid),
device_list_strategy: Some(NvidiaDeviceListStrategy::Envvar),
device_sharing_strategy: None,
time_slicing: None
device_partitioning_strategy: None,
time_slicing: None,
mig: None,
}
);
let results = serde_json::to_string(&nvidia_device_plugins).unwrap();
Expand All @@ -1524,7 +1749,9 @@ mod tests {
device_id_strategy: Some(NvidiaDeviceIdStrategy::Uuid),
device_list_strategy: Some(NvidiaDeviceListStrategy::Envvar),
device_sharing_strategy: Some(NvidiaDeviceSharingStrategy::TimeSlicing),
time_slicing: None
device_partitioning_strategy: None,
time_slicing: None,
mig: None,
}
);

Expand Down
6 changes: 6 additions & 0 deletions bottlerocket-settings-models/modeled-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ pub mod error {
#[snafu(display("Invalid ECS duration value '{}'", input))]
InvalidECSDurationValue { input: String },

#[snafu(display("Invalid MIG Profile value '{}'", input))]
InvalidMIGProfile {
input: String,
source: serde_plain::Error,
},

#[snafu(display("Could not parse '{}' as an integer", input))]
ParseInt {
input: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ impl SettingsModel for KubeletDevicePluginsV1 {
mod test {
use super::*;
use bottlerocket_modeled_types::{
NvidiaDeviceIdStrategy, NvidiaDeviceListStrategy, NvidiaDeviceSharingStrategy,
NvidiaTimeSlicingSettings,
MIGA100Profile, MIGH100Profile, MIGH200Profile, NvidiaDeviceIdStrategy,
NvidiaDeviceListStrategy, NvidiaDevicePartitioningStrategy, NvidiaDeviceSharingStrategy,
NvidiaMIGSettings, NvidiaTimeSlicingSettings,
};
use bounded_integer::BoundedI32;

Expand All @@ -59,7 +60,7 @@ mod test {

#[test]
fn test_serde_kubelet_device_plugins() {
let test_json = r#"{"nvidia":{"pass-device-specs":true,"device-id-strategy":"index","device-list-strategy":"volume-mounts","device-sharing-strategy":"time-slicing","time-slicing":{"replicas":2,"rename-by-default":true,"fail-requests-greater-than-one":true}}}"#;
let test_json = r#"{"nvidia":{"pass-device-specs":true,"device-id-strategy":"index","device-list-strategy":"volume-mounts","device-sharing-strategy":"time-slicing","device-partitioning-strategy":"mig","time-slicing":{"replicas":2,"rename-by-default":true,"fail-requests-greater-than-one":true},"mig":{"profile-a100":"1g.5gb","profile-h100":"7g.80gb","profile-h200":"7g.141gb"}}}"#;

let device_plugins: KubeletDevicePluginsV1 = serde_json::from_str(test_json).unwrap();
assert_eq!(
Expand All @@ -70,11 +71,17 @@ mod test {
device_id_strategy: Some(NvidiaDeviceIdStrategy::Index),
device_list_strategy: Some(NvidiaDeviceListStrategy::VolumeMounts),
device_sharing_strategy: Some(NvidiaDeviceSharingStrategy::TimeSlicing),
device_partitioning_strategy: Some(NvidiaDevicePartitioningStrategy::MIG),
time_slicing: Some(NvidiaTimeSlicingSettings {
replicas: Some(BoundedI32::new(2).unwrap()),
rename_by_default: Some(true),
fail_requests_greater_than_one: Some(true),
}),
mig: Some(NvidiaMIGSettings {
profile_a100: Some(MIGA100Profile::try_from("1g.5gb").unwrap()),
profile_h100: Some(MIGH100Profile::try_from("7g.80gb").unwrap()),
profile_h200: Some(MIGH200Profile::try_from("7g.141gb").unwrap()),
}),
}),
}
);
Expand Down
2 changes: 1 addition & 1 deletion bottlerocket-settings-models/settings-models/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "bottlerocket-settings-models"
version = "0.6.0"
version = "0.7.0"
authors = ["Tom Kirchner <[email protected]>"]
license = "Apache-2.0 OR MIT"
edition = "2021"
Expand Down
1 change: 1 addition & 0 deletions deny.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ allow = [
"MIT",
# "OpenSSL",
# "Unlicense",
"Unicode-3.0",
"Zlib",
]

Expand Down

0 comments on commit a5f78d1

Please sign in to comment.