From 6e14de4e6bac7777b09b23bb6cb223590a5a6251 Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Thu, 23 Oct 2025 18:17:57 +0800 Subject: [PATCH 1/8] feat: add Triton model config specification in TensorModelConfig Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- lib/llm/build.rs | 4 +++- lib/llm/src/grpc/service/kserve.rs | 30 ++++++++++++++++++++++++++++++ lib/llm/src/protocols/tensor.rs | 4 ++++ 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/lib/llm/build.rs b/lib/llm/build.rs index e116eca68b..deebc7a5a6 100644 --- a/lib/llm/build.rs +++ b/lib/llm/build.rs @@ -44,7 +44,9 @@ fn main() -> Result<(), Box> { } fn build_protos() -> Result<(), Box> { - tonic_build::compile_protos("src/grpc/protos/kserve.proto")?; + tonic_build::configure() + .type_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]") + .compile_protos(&["kserve.proto"], &["src/grpc/protos"])?; Ok(()) } diff --git a/lib/llm/src/grpc/service/kserve.rs b/lib/llm/src/grpc/service/kserve.rs index bf89dc2fa2..33b3cedfe9 100644 --- a/lib/llm/src/grpc/service/kserve.rs +++ b/lib/llm/src/grpc/service/kserve.rs @@ -418,6 +418,30 @@ impl GrpcInferenceService for KserveService { if card.model_type.supports_tensor() { if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref() { + if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() { + let model_config : ModelConfig = serde_json::from_value(triton_model_config.clone()).map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))?; + return Ok(Response::new(ModelMetadataResponse { + name: model_config.name, + versions: vec!["1".to_string()], + platform: model_config.platform, + inputs: model_config.input.iter().map(|input| inference::model_metadata_response::TensorMetadata { + name: input.name.clone(), + datatype: match inference::DataType::try_from(input.data_type) { + Ok(dt) => dt.as_str_name().to_string(), + Err(_) => "TYPE_INVALID".to_string(), + }, + shape: input.dims.clone(), + }).collect(), + outputs: model_config.output.iter().map(|output| inference::model_metadata_response::TensorMetadata { + name: output.name.clone(), + datatype: match inference::DataType::try_from(output.data_type) { + Ok(dt) => dt.as_str_name().to_string(), + Err(_) => "TYPE_INVALID".to_string(), + }, + shape: output.dims.clone(), + }).collect(), + })); + } return Ok(Response::new(ModelMetadataResponse { name: tensor_model_config.name.clone(), versions: vec!["1".to_string()], @@ -499,6 +523,12 @@ impl GrpcInferenceService for KserveService { if card.model_type.supports_tensor() { if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref() { + if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() { + let model_config : ModelConfig = serde_json::from_value(triton_model_config.clone()).map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))?; + return Ok(Response::new(ModelConfigResponse { + config: Some(model_config), + })); + } let model_config = ModelConfig { name: tensor_model_config.name.clone(), platform: "dynamo".to_string(), diff --git a/lib/llm/src/protocols/tensor.rs b/lib/llm/src/protocols/tensor.rs index fc19f5a32d..a43a637210 100644 --- a/lib/llm/src/protocols/tensor.rs +++ b/lib/llm/src/protocols/tensor.rs @@ -124,6 +124,10 @@ pub struct TensorModelConfig { pub name: String, pub inputs: Vec, pub outputs: Vec, + // Optional Triton model 
config in serialized protobuf string, + // if provided, it supercedes the basic model config defined above. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub triton_model_config: Option, } #[derive(Serialize, Deserialize, Debug, Clone)] From 5eef909203c75eaf96aafaf162045fdc54103eef Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Fri, 24 Oct 2025 13:31:07 +0800 Subject: [PATCH 2/8] test: add unit test for using Triton model config Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- lib/llm/src/protocols/tensor.rs | 11 ++++ lib/llm/tests/kserve_service.rs | 92 +++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+) diff --git a/lib/llm/src/protocols/tensor.rs b/lib/llm/src/protocols/tensor.rs index a43a637210..bf59a5350c 100644 --- a/lib/llm/src/protocols/tensor.rs +++ b/lib/llm/src/protocols/tensor.rs @@ -130,6 +130,17 @@ pub struct TensorModelConfig { pub triton_model_config: Option, } +impl Default for TensorModelConfig { + fn default() -> Self { + Self { + name: "".to_string(), + inputs: vec![], + outputs: vec![], + triton_model_config: None, + } + } +} + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Tensor { pub metadata: TensorMetadata, diff --git a/lib/llm/tests/kserve_service.rs b/lib/llm/tests/kserve_service.rs index e584491933..eee210b46d 100644 --- a/lib/llm/tests/kserve_service.rs +++ b/lib/llm/tests/kserve_service.rs @@ -359,6 +359,7 @@ pub mod kserve_test { ModelInfo = 8994, TensorModel = 8995, TensorModelTypes = 8996, + TritonModelConfig = 8997, } #[rstest] @@ -1169,6 +1170,7 @@ pub mod kserve_test { data_type: tensor::DataType::Bool, shape: vec![-1], }], + triton_model_config: None, }), ..Default::default() }; @@ -1202,6 +1204,95 @@ pub mod kserve_test { ); } + #[rstest] + #[tokio::test] + async fn test_triton_model_config( + #[with(TestPort::TritonModelConfig as u16)] service_with_engines: ( + KserveService, + Arc, + Arc, + Arc, + ), + ) { + // start server + let _running = RunningService::spawn(service_with_engines.0.clone()); + + let mut client = get_ready_client(TestPort::TritonModelConfig as u16, 5).await; + + let model_name = "tensor"; + let expected_model_config = inference::ModelConfig { + name: model_name.to_string(), + platform: "custom".to_string(), + backend: "custom".to_string(), + input: vec![inference::ModelInput { + name: "input".to_string(), + data_type: DataType::TypeInt32 as i32, + dims: vec![1], + optional: false, + ..Default::default() + }, + inference::ModelInput { + name: "optional_input".to_string(), + data_type: DataType::TypeInt32 as i32, + dims: vec![1], + optional: true, + ..Default::default() + }], + output: vec![inference::ModelOutput { + name: "output".to_string(), + data_type: DataType::TypeBool as i32, + dims: vec![-1], + ..Default::default() + }], + model_transaction_policy: Some(inference::ModelTransactionPolicy { + decoupled: true, + }), + ..Default::default() + }; + + // Register a tensor model + let mut card = ModelDeploymentCard::with_name_only(model_name); + card.model_type = ModelType::TensorBased; + card.model_input = ModelInput::Tensor; + card.runtime_config = ModelRuntimeConfig { + tensor_model_config: Some(tensor::TensorModelConfig { + triton_model_config: Some(serde_json::to_value(expected_model_config.clone()).unwrap()), + ..Default::default() + }), + ..Default::default() + }; + let tensor = Arc::new(TensorEngine {}); + service_with_engines + .0 + .model_manager() + .add_tensor_model("tensor", card.mdcsum(), tensor.clone()) + 
.unwrap(); + let _ = service_with_engines + .0 + .model_manager() + .save_model_card("key", card); + + // success config + let request = tonic::Request::new(ModelConfigRequest { + name: model_name.into(), + version: "".into(), + }); + + let response = client + .model_config(request) + .await + .unwrap() + .into_inner() + .config; + let Some(config) = response else { + panic!("Expected Some(config), got None"); + }; + assert_eq!( + config, expected_model_config, + "Expected same model config to be returned", + ); + } + #[rstest] #[tokio::test] async fn test_tensor_infer( @@ -1299,6 +1390,7 @@ pub mod kserve_test { data_type: tensor::DataType::Bool, shape: vec![-1], }], + triton_model_config: None, }), ..Default::default() }; From a4b0fdd7fa72ee1079c70c4516649cd08f1ae763 Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Fri, 24 Oct 2025 19:13:09 +0800 Subject: [PATCH 3/8] fix: fix message format for passing Triton model config Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- Cargo.lock | 1 + lib/bindings/python/Cargo.lock | 1 + lib/llm/Cargo.toml | 1 + lib/llm/src/grpc/service/kserve.rs | 9 +++-- lib/llm/src/protocols/tensor.rs | 3 +- lib/llm/tests/kserve_service.rs | 8 ++++- tests/frontend/grpc/echo_tensor_worker.py | 36 ++++++++++++++++--- .../grpc/test_tensor_mocker_engine.py | 1 + tests/frontend/grpc/triton_echo_client.py | 14 ++++++++ 9 files changed, 65 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c16ec8eaeb..27ca11131f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2136,6 +2136,7 @@ dependencies = [ "async_zmq", "axum 0.8.4", "axum-server", + "base64 0.22.1", "bitflags 2.9.4", "blake3", "bs62", diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index 3992bcd453..fa0007b53f 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -1466,6 +1466,7 @@ dependencies = [ "async_zmq", "axum", "axum-server", + "base64 0.22.1", "bitflags 2.9.3", "blake3", "bs62", diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index 982b128e5b..44546c9501 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -116,6 +116,7 @@ utoipa-swagger-ui = { version = "9.0", features = ["axum"] } tonic = { version = "0.13.1" } # Request prost specifically so tonic-build properly compiles protobuf message prost = { version = "0.13.5" } +base64 = "0.22.1" # tokenizers tokenizers = { version = "0.21.4", default-features = false, features = [ diff --git a/lib/llm/src/grpc/service/kserve.rs b/lib/llm/src/grpc/service/kserve.rs index 33b3cedfe9..8ce71cc543 100644 --- a/lib/llm/src/grpc/service/kserve.rs +++ b/lib/llm/src/grpc/service/kserve.rs @@ -39,6 +39,9 @@ use inference::{ ModelMetadataRequest, ModelMetadataResponse, ModelStreamInferResponse, }; +use base64::engine::{Engine, general_purpose}; +use prost::Message; + /// [gluo TODO] 'metrics' are for HTTP service and there is HTTP endpoint /// for it as part of HTTP service. Should we always start HTTP service up /// for non-inference? 
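At this point in the series the Triton ModelConfig travels inside the model deployment card as a base64-encoded, prost-serialized string, so the frontend has to decode it in two steps before it can answer metadata or config requests. A minimal sketch of that decode path, assuming the generated inference::ModelConfig prost type and the base64 0.22 API pulled in by the imports above (a later patch in this series drops the base64 layer and stores raw bytes instead):

use base64::Engine as _;
use base64::engine::general_purpose;
use prost::Message;

fn decode_triton_config(encoded: &str) -> Result<inference::ModelConfig, tonic::Status> {
    // Step 1: strip the base64 text encoding used for transport.
    let bytes = general_purpose::STANDARD
        .decode(encoded)
        .map_err(|e| tonic::Status::invalid_argument(format!("Failed to decode base64 model config: {}", e)))?;
    // Step 2: decode the raw protobuf bytes into the generated ModelConfig message.
    inference::ModelConfig::decode(bytes.as_slice())
        .map_err(|e| tonic::Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))
}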
@@ -419,7 +422,8 @@ impl GrpcInferenceService for KserveService { if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref() { if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() { - let model_config : ModelConfig = serde_json::from_value(triton_model_config.clone()).map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))?; + let bytes = general_purpose::STANDARD.decode(triton_model_config.clone()).map_err(|e| Status::invalid_argument(format!("Failed to decode base64 model config: {}", e)))?; + let model_config = ModelConfig::decode(&*bytes).map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))?; return Ok(Response::new(ModelMetadataResponse { name: model_config.name, versions: vec!["1".to_string()], @@ -524,7 +528,8 @@ impl GrpcInferenceService for KserveService { if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref() { if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() { - let model_config : ModelConfig = serde_json::from_value(triton_model_config.clone()).map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))?; + let bytes = general_purpose::STANDARD.decode(triton_model_config.clone()).map_err(|e| Status::invalid_argument(format!("Failed to decode base64 model config: {}", e)))?; + let model_config = ModelConfig::decode(&*bytes).map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))?; return Ok(Response::new(ModelConfigResponse { config: Some(model_config), })); diff --git a/lib/llm/src/protocols/tensor.rs b/lib/llm/src/protocols/tensor.rs index bf59a5350c..7525d6b91c 100644 --- a/lib/llm/src/protocols/tensor.rs +++ b/lib/llm/src/protocols/tensor.rs @@ -126,8 +126,9 @@ pub struct TensorModelConfig { pub outputs: Vec, // Optional Triton model config in serialized protobuf string, // if provided, it supercedes the basic model config defined above. + // The string is base64-encoded to ensure safe transport over text-based protocols. 
#[serde(default, skip_serializing_if = "Option::is_none")] - pub triton_model_config: Option, + pub triton_model_config: Option, } impl Default for TensorModelConfig { diff --git a/lib/llm/tests/kserve_service.rs b/lib/llm/tests/kserve_service.rs index eee210b46d..4548ff2583 100644 --- a/lib/llm/tests/kserve_service.rs +++ b/lib/llm/tests/kserve_service.rs @@ -41,6 +41,8 @@ pub mod kserve_test { use tonic::{Request, Response, transport::Channel}; use dynamo_async_openai::types::Prompt; + use prost::Message; + use base64::engine::{Engine, general_purpose}; struct SplitEngine {} @@ -1250,13 +1252,17 @@ pub mod kserve_test { ..Default::default() }; + let mut buf = vec![]; + expected_model_config.encode(&mut buf).unwrap(); + let buf = general_purpose::STANDARD.encode(&buf); + // Register a tensor model let mut card = ModelDeploymentCard::with_name_only(model_name); card.model_type = ModelType::TensorBased; card.model_input = ModelInput::Tensor; card.runtime_config = ModelRuntimeConfig { tensor_model_config: Some(tensor::TensorModelConfig { - triton_model_config: Some(serde_json::to_value(expected_model_config.clone()).unwrap()), + triton_model_config: Some(buf), ..Default::default() }), ..Default::default() diff --git a/tests/frontend/grpc/echo_tensor_worker.py b/tests/frontend/grpc/echo_tensor_worker.py index d66505801f..04ee806131 100644 --- a/tests/frontend/grpc/echo_tensor_worker.py +++ b/tests/frontend/grpc/echo_tensor_worker.py @@ -4,7 +4,10 @@ # Usage: `TEST_END_TO_END=1 python test_tensor.py` to run this worker as tensor based echo worker. +from base64 import b64encode + import uvloop +from model_config_pb2 import ModelConfig from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm from dynamo.runtime import DistributedRuntime, dynamo_worker @@ -17,12 +20,32 @@ async def echo_tensor_worker(runtime: DistributedRuntime): endpoint = component.endpoint("generate") + triton_model_config = ModelConfig() + triton_model_config.name = "echo" + triton_model_config.platform = "custom" + input_tensor = triton_model_config.input.add() + input_tensor.name = "input" + input_tensor.data_type = "TYPE_STRING" + input_tensor.dims.extend([-1]) + optional_input_tensor = triton_model_config.input.add() + optional_input_tensor.name = "optional_input" + optional_input_tensor.data_type = "TYPE_INT32" + optional_input_tensor.dims.extend([-1]) + optional_input_tensor.optional = True + output_tensor = triton_model_config.output.add() + output_tensor.name = "dummy_output" + output_tensor.data_type = "TYPE_STRING" + output_tensor.dims.extend([-1]) + triton_model_config.model_transaction_policy.decoupled = True + + # Serialize and base64-encode the Triton model config + b64_config = b64encode(triton_model_config.SerializeToString()) + model_config = { - "name": "echo", - "inputs": [ - {"name": "dummy_input", "data_type": "Bytes", "shape": [-1]}, - ], - "outputs": [{"name": "dummy_output", "data_type": "Bytes", "shape": [-1]}], + "name": "", + "inputs": [], + "outputs": [], + "triton_model_config": b64_config, } runtime_config = ModelRuntimeConfig() runtime_config.set_tensor_model_config(model_config) @@ -45,6 +68,9 @@ async def echo_tensor_worker(runtime: DistributedRuntime): async def generate(request, context): + # [NOTE] gluo: currently there is no frontend side + # validation between model config and actual request, + # so any request will reach here and be echoed back. 
print(f"Echoing request: {request}") yield {"model": request["model"], "tensors": request["tensors"]} diff --git a/tests/frontend/grpc/test_tensor_mocker_engine.py b/tests/frontend/grpc/test_tensor_mocker_engine.py index 6bb7ff2c09..862f084ea0 100644 --- a/tests/frontend/grpc/test_tensor_mocker_engine.py +++ b/tests/frontend/grpc/test_tensor_mocker_engine.py @@ -120,3 +120,4 @@ def start_services(request, runtime_services): @pytest.mark.model(TEST_MODEL) def test_echo() -> None: triton_echo_client.run_infer() + triton_echo_client.get_config() diff --git a/tests/frontend/grpc/triton_echo_client.py b/tests/frontend/grpc/triton_echo_client.py index cc2cb27167..3e94333ab5 100644 --- a/tests/frontend/grpc/triton_echo_client.py +++ b/tests/frontend/grpc/triton_echo_client.py @@ -43,3 +43,17 @@ def run_infer(): assert np.array_equal(input0_data, output0_data) assert np.array_equal(input1_data, output1_data) + + +def get_config(): + server_url = "localhost:8000" + try: + triton_client = grpcclient.InferenceServerClient(url=server_url) + except Exception as e: + print("channel creation failed: " + str(e)) + sys.exit() + + model_name = "echo" + response = triton_client.get_model_config(model_name=model_name) + # Check one of the field that can only be set by providing Triton model config + assert response.config.model_transaction_policy.decoupled From 67d14e6da9b65440c46baa6428cad471d41360e0 Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Fri, 24 Oct 2025 19:18:09 +0800 Subject: [PATCH 4/8] fix: where to import model config Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- tests/frontend/grpc/echo_tensor_worker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/frontend/grpc/echo_tensor_worker.py b/tests/frontend/grpc/echo_tensor_worker.py index 04ee806131..c13f75049d 100644 --- a/tests/frontend/grpc/echo_tensor_worker.py +++ b/tests/frontend/grpc/echo_tensor_worker.py @@ -7,7 +7,10 @@ from base64 import b64encode import uvloop -from model_config_pb2 import ModelConfig + +# Knowing the test will be run in environment that has tritonclient installed, +# which contain the generated file equivalent to model_config.proto. 
+from tritonclient.grpc.model_config_pb2 import ModelConfig from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm from dynamo.runtime import DistributedRuntime, dynamo_worker From 8b9f4e668e010f292d1524c43711d6742b65f399 Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Fri, 24 Oct 2025 19:29:40 +0800 Subject: [PATCH 5/8] style: style Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- lib/llm/src/grpc/service/kserve.rs | 84 ++++++++++++++++++++++-------- lib/llm/tests/kserve_service.rs | 22 ++++---- 2 files changed, 73 insertions(+), 33 deletions(-) diff --git a/lib/llm/src/grpc/service/kserve.rs b/lib/llm/src/grpc/service/kserve.rs index 8ce71cc543..38e7359e51 100644 --- a/lib/llm/src/grpc/service/kserve.rs +++ b/lib/llm/src/grpc/service/kserve.rs @@ -421,29 +421,55 @@ impl GrpcInferenceService for KserveService { if card.model_type.supports_tensor() { if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref() { - if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() { - let bytes = general_purpose::STANDARD.decode(triton_model_config.clone()).map_err(|e| Status::invalid_argument(format!("Failed to decode base64 model config: {}", e)))?; - let model_config = ModelConfig::decode(&*bytes).map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))?; + if let Some(triton_model_config) = + tensor_model_config.triton_model_config.as_ref() + { + let bytes = general_purpose::STANDARD + .decode(triton_model_config.clone()) + .map_err(|e| { + Status::invalid_argument(format!( + "Failed to decode base64 model config: {}", + e + )) + })?; + let model_config = ModelConfig::decode(&*bytes).map_err(|e| { + Status::invalid_argument(format!( + "Failed to deserialize model config: {}", + e + )) + })?; return Ok(Response::new(ModelMetadataResponse { name: model_config.name, versions: vec!["1".to_string()], platform: model_config.platform, - inputs: model_config.input.iter().map(|input| inference::model_metadata_response::TensorMetadata { - name: input.name.clone(), - datatype: match inference::DataType::try_from(input.data_type) { - Ok(dt) => dt.as_str_name().to_string(), - Err(_) => "TYPE_INVALID".to_string(), - }, - shape: input.dims.clone(), - }).collect(), - outputs: model_config.output.iter().map(|output| inference::model_metadata_response::TensorMetadata { - name: output.name.clone(), - datatype: match inference::DataType::try_from(output.data_type) { - Ok(dt) => dt.as_str_name().to_string(), - Err(_) => "TYPE_INVALID".to_string(), - }, - shape: output.dims.clone(), - }).collect(), + inputs: model_config + .input + .iter() + .map(|input| inference::model_metadata_response::TensorMetadata { + name: input.name.clone(), + datatype: match inference::DataType::try_from(input.data_type) { + Ok(dt) => dt.as_str_name().to_string(), + Err(_) => "TYPE_INVALID".to_string(), + }, + shape: input.dims.clone(), + }) + .collect(), + outputs: model_config + .output + .iter() + .map( + |output| inference::model_metadata_response::TensorMetadata { + name: output.name.clone(), + datatype: match inference::DataType::try_from( + output.data_type, + ) { + Ok(dt) => dt.as_str_name().to_string(), + Err(_) => "TYPE_INVALID".to_string(), + }, + shape: output.dims.clone(), + }, + ) + .collect(), })); } return Ok(Response::new(ModelMetadataResponse { @@ -527,9 +553,23 @@ impl GrpcInferenceService for KserveService { if card.model_type.supports_tensor() { if 
let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref() { - if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() { - let bytes = general_purpose::STANDARD.decode(triton_model_config.clone()).map_err(|e| Status::invalid_argument(format!("Failed to decode base64 model config: {}", e)))?; - let model_config = ModelConfig::decode(&*bytes).map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))?; + if let Some(triton_model_config) = + tensor_model_config.triton_model_config.as_ref() + { + let bytes = general_purpose::STANDARD + .decode(triton_model_config.clone()) + .map_err(|e| { + Status::invalid_argument(format!( + "Failed to decode base64 model config: {}", + e + )) + })?; + let model_config = ModelConfig::decode(&*bytes).map_err(|e| { + Status::invalid_argument(format!( + "Failed to deserialize model config: {}", + e + )) + })?; return Ok(Response::new(ModelConfigResponse { config: Some(model_config), })); diff --git a/lib/llm/tests/kserve_service.rs b/lib/llm/tests/kserve_service.rs index 4548ff2583..0130ba9597 100644 --- a/lib/llm/tests/kserve_service.rs +++ b/lib/llm/tests/kserve_service.rs @@ -40,9 +40,9 @@ pub mod kserve_test { use tokio::time::timeout; use tonic::{Request, Response, transport::Channel}; + use base64::engine::{Engine, general_purpose}; use dynamo_async_openai::types::Prompt; use prost::Message; - use base64::engine::{Engine, general_purpose}; struct SplitEngine {} @@ -1226,7 +1226,8 @@ pub mod kserve_test { name: model_name.to_string(), platform: "custom".to_string(), backend: "custom".to_string(), - input: vec![inference::ModelInput { + input: vec![ + inference::ModelInput { name: "input".to_string(), data_type: DataType::TypeInt32 as i32, dims: vec![1], @@ -1239,16 +1240,15 @@ pub mod kserve_test { dims: vec![1], optional: true, ..Default::default() - }], + }, + ], output: vec![inference::ModelOutput { - name: "output".to_string(), - data_type: DataType::TypeBool as i32, - dims: vec![-1], - ..Default::default() - }], - model_transaction_policy: Some(inference::ModelTransactionPolicy { - decoupled: true, - }), + name: "output".to_string(), + data_type: DataType::TypeBool as i32, + dims: vec![-1], + ..Default::default() + }], + model_transaction_policy: Some(inference::ModelTransactionPolicy { decoupled: true }), ..Default::default() }; From efb015d6ef2c6a4dea7901d7b8f2b5ee3ec1506b Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Fri, 24 Oct 2025 20:34:08 +0800 Subject: [PATCH 6/8] chore: address comment Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- Cargo.lock | 1 - lib/bindings/python/Cargo.lock | 1 - lib/llm/Cargo.toml | 1 - lib/llm/src/grpc/service/kserve.rs | 23 ++++--------------- lib/llm/src/protocols/tensor.rs | 5 ++--- lib/llm/tests/kserve_service.rs | 2 -- tests/frontend/grpc/echo_tensor_worker.py | 27 +++++++++++------------ 7 files changed, 19 insertions(+), 41 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1f397e3c03..d7b47e8ef2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2136,7 +2136,6 @@ dependencies = [ "async_zmq", "axum 0.8.4", "axum-server", - "base64 0.22.1", "bincode", "bitflags 2.9.4", "blake3", diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index a64f8e73cd..4586591b15 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -1466,7 +1466,6 @@ dependencies = [ "async_zmq", "axum", "axum-server", - "base64 0.22.1", "bincode", 
"bitflags 2.9.3", "blake3", diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index d0868eb7eb..d74056150b 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -117,7 +117,6 @@ utoipa-swagger-ui = { version = "9.0", features = ["axum"] } tonic = { version = "0.13.1" } # Request prost specifically so tonic-build properly compiles protobuf message prost = { version = "0.13.5" } -base64 = "0.22.1" # tokenizers tokenizers = { version = "0.21.4", default-features = false, features = [ diff --git a/lib/llm/src/grpc/service/kserve.rs b/lib/llm/src/grpc/service/kserve.rs index 38e7359e51..7ba2f41d91 100644 --- a/lib/llm/src/grpc/service/kserve.rs +++ b/lib/llm/src/grpc/service/kserve.rs @@ -39,7 +39,6 @@ use inference::{ ModelMetadataRequest, ModelMetadataResponse, ModelStreamInferResponse, }; -use base64::engine::{Engine, general_purpose}; use prost::Message; /// [gluo TODO] 'metrics' are for HTTP service and there is HTTP endpoint @@ -424,20 +423,13 @@ impl GrpcInferenceService for KserveService { if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() { - let bytes = general_purpose::STANDARD - .decode(triton_model_config.clone()) + let model_config = ModelConfig::decode(triton_model_config.as_slice()) .map_err(|e| { Status::invalid_argument(format!( - "Failed to decode base64 model config: {}", + "Failed to deserialize model config: {}", e )) })?; - let model_config = ModelConfig::decode(&*bytes).map_err(|e| { - Status::invalid_argument(format!( - "Failed to deserialize model config: {}", - e - )) - })?; return Ok(Response::new(ModelMetadataResponse { name: model_config.name, versions: vec!["1".to_string()], @@ -556,20 +548,13 @@ impl GrpcInferenceService for KserveService { if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() { - let bytes = general_purpose::STANDARD - .decode(triton_model_config.clone()) + let model_config = ModelConfig::decode(triton_model_config.as_slice()) .map_err(|e| { Status::invalid_argument(format!( - "Failed to decode base64 model config: {}", + "Failed to deserialize model config: {}", e )) })?; - let model_config = ModelConfig::decode(&*bytes).map_err(|e| { - Status::invalid_argument(format!( - "Failed to deserialize model config: {}", - e - )) - })?; return Ok(Response::new(ModelConfigResponse { config: Some(model_config), })); diff --git a/lib/llm/src/protocols/tensor.rs b/lib/llm/src/protocols/tensor.rs index 7525d6b91c..9c117dc80c 100644 --- a/lib/llm/src/protocols/tensor.rs +++ b/lib/llm/src/protocols/tensor.rs @@ -125,10 +125,9 @@ pub struct TensorModelConfig { pub inputs: Vec, pub outputs: Vec, // Optional Triton model config in serialized protobuf string, - // if provided, it supercedes the basic model config defined above. - // The string is base64-encoded to ensure safe transport over text-based protocols. + // if provided, it supersedes the basic model config defined above. 
#[serde(default, skip_serializing_if = "Option::is_none")] - pub triton_model_config: Option, + pub triton_model_config: Option>, } impl Default for TensorModelConfig { diff --git a/lib/llm/tests/kserve_service.rs b/lib/llm/tests/kserve_service.rs index 0130ba9597..e405c9961f 100644 --- a/lib/llm/tests/kserve_service.rs +++ b/lib/llm/tests/kserve_service.rs @@ -40,7 +40,6 @@ pub mod kserve_test { use tokio::time::timeout; use tonic::{Request, Response, transport::Channel}; - use base64::engine::{Engine, general_purpose}; use dynamo_async_openai::types::Prompt; use prost::Message; @@ -1254,7 +1253,6 @@ pub mod kserve_test { let mut buf = vec![]; expected_model_config.encode(&mut buf).unwrap(); - let buf = general_purpose::STANDARD.encode(&buf); // Register a tensor model let mut card = ModelDeploymentCard::with_name_only(model_name); diff --git a/tests/frontend/grpc/echo_tensor_worker.py b/tests/frontend/grpc/echo_tensor_worker.py index c13f75049d..1cc2ea66b7 100644 --- a/tests/frontend/grpc/echo_tensor_worker.py +++ b/tests/frontend/grpc/echo_tensor_worker.py @@ -4,13 +4,10 @@ # Usage: `TEST_END_TO_END=1 python test_tensor.py` to run this worker as tensor based echo worker. -from base64 import b64encode - -import uvloop - # Knowing the test will be run in environment that has tritonclient installed, # which contain the generated file equivalent to model_config.proto. -from tritonclient.grpc.model_config_pb2 import ModelConfig +import tritonclient.grpc.model_config_pb2 as mc +import uvloop from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm from dynamo.runtime import DistributedRuntime, dynamo_worker @@ -23,37 +20,39 @@ async def echo_tensor_worker(runtime: DistributedRuntime): endpoint = component.endpoint("generate") - triton_model_config = ModelConfig() + triton_model_config = mc.ModelConfig() triton_model_config.name = "echo" triton_model_config.platform = "custom" input_tensor = triton_model_config.input.add() input_tensor.name = "input" - input_tensor.data_type = "TYPE_STRING" + input_tensor.data_type = mc.TYPE_STRING input_tensor.dims.extend([-1]) optional_input_tensor = triton_model_config.input.add() optional_input_tensor.name = "optional_input" - optional_input_tensor.data_type = "TYPE_INT32" + optional_input_tensor.data_type = mc.TYPE_INT32 optional_input_tensor.dims.extend([-1]) optional_input_tensor.optional = True output_tensor = triton_model_config.output.add() output_tensor.name = "dummy_output" - output_tensor.data_type = "TYPE_STRING" + output_tensor.data_type = mc.TYPE_STRING output_tensor.dims.extend([-1]) triton_model_config.model_transaction_policy.decoupled = True - # Serialize and base64-encode the Triton model config - b64_config = b64encode(triton_model_config.SerializeToString()) - model_config = { "name": "", "inputs": [], "outputs": [], - "triton_model_config": b64_config, + "triton_model_config": triton_model_config.SerializeToString(), } runtime_config = ModelRuntimeConfig() runtime_config.set_tensor_model_config(model_config) - assert model_config == runtime_config.get_tensor_model_config() + # Internally the bytes string will be converted to List of int + retrieved_model_config = runtime_config.get_tensor_model_config() + retrieved_model_config["triton_model_config"] = bytes( + retrieved_model_config["triton_model_config"] + ) + assert model_config == retrieved_model_config # [gluo FIXME] register_llm will attempt to load a LLM model, # which is not well-defined for Tensor yet. 
Currently provide From 3d54ce76717bea7644b112fb48d62add1553b43a Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Wed, 29 Oct 2025 01:55:28 +0800 Subject: [PATCH 7/8] chore: address comment Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- lib/llm/src/grpc/service/kserve.rs | 183 +++++++++++++++-------------- lib/llm/src/protocols/tensor.rs | 13 +- lib/llm/tests/kserve_service.rs | 10 +- 3 files changed, 102 insertions(+), 104 deletions(-) diff --git a/lib/llm/src/grpc/service/kserve.rs b/lib/llm/src/grpc/service/kserve.rs index 7ba2f41d91..c90bf97e0f 100644 --- a/lib/llm/src/grpc/service/kserve.rs +++ b/lib/llm/src/grpc/service/kserve.rs @@ -11,6 +11,8 @@ use crate::http::service::Metrics; use crate::http::service::metrics; use crate::discovery::ModelManager; +use crate::local_model::runtime_config::ModelRuntimeConfig; +use crate::protocols::tensor::TensorModelConfig; use crate::protocols::tensor::{NvCreateTensorRequest, NvCreateTensorResponse}; use crate::request_template::RequestTemplate; use anyhow::Result; @@ -185,6 +187,27 @@ impl KserveServiceConfigBuilder { } } +#[allow(clippy::large_enum_variant)] +enum Config { + Dynamo(TensorModelConfig), + Triton(ModelConfig), +} + +impl Config { + fn from_runtime_config(runtime_config: &ModelRuntimeConfig) -> Result { + if let Some(tensor_model_config) = runtime_config.tensor_model_config.as_ref() { + if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() { + let model_config = ModelConfig::decode(triton_model_config.as_slice())?; + Ok(Config::Triton(model_config)) + } else { + Ok(Config::Dynamo(tensor_model_config.clone())) + } + } else { + Err(anyhow::anyhow!("no model config is provided")) + } + } +} + #[tonic::async_trait] impl GrpcInferenceService for KserveService { async fn model_infer( @@ -418,18 +441,14 @@ impl GrpcInferenceService for KserveService { .find(|card| request_model_name == &card.display_name) { if card.model_type.supports_tensor() { - if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref() - { - if let Some(triton_model_config) = - tensor_model_config.triton_model_config.as_ref() - { - let model_config = ModelConfig::decode(triton_model_config.as_slice()) - .map_err(|e| { - Status::invalid_argument(format!( - "Failed to deserialize model config: {}", - e - )) - })?; + let config = Config::from_runtime_config(&card.runtime_config).map_err(|e| { + Status::invalid_argument(format!( + "Model '{}' has type Tensor but: {}", + request_model_name, e + )) + })?; + match config { + Config::Triton(model_config) => { return Ok(Response::new(ModelMetadataResponse { name: model_config.name, versions: vec!["1".to_string()], @@ -464,36 +483,34 @@ impl GrpcInferenceService for KserveService { .collect(), })); } - return Ok(Response::new(ModelMetadataResponse { - name: tensor_model_config.name.clone(), - versions: vec!["1".to_string()], - platform: "dynamo".to_string(), - inputs: tensor_model_config - .inputs - .iter() - .map(|input| inference::model_metadata_response::TensorMetadata { - name: input.name.clone(), - datatype: input.data_type.to_string(), - shape: input.shape.clone(), - }) - .collect(), - outputs: tensor_model_config - .outputs - .iter() - .map( - |output| inference::model_metadata_response::TensorMetadata { - name: output.name.clone(), - datatype: output.data_type.to_string(), - shape: output.shape.clone(), - }, - ) - .collect(), - })); + Config::Dynamo(model_config) => { + return 
Ok(Response::new(ModelMetadataResponse { + name: model_config.name.clone(), + versions: vec!["1".to_string()], + platform: "dynamo".to_string(), + inputs: model_config + .inputs + .iter() + .map(|input| inference::model_metadata_response::TensorMetadata { + name: input.name.clone(), + datatype: input.data_type.to_string(), + shape: input.shape.clone(), + }) + .collect(), + outputs: model_config + .outputs + .iter() + .map( + |output| inference::model_metadata_response::TensorMetadata { + name: output.name.clone(), + datatype: output.data_type.to_string(), + shape: output.shape.clone(), + }, + ) + .collect(), + })); + } } - Err(Status::invalid_argument(format!( - "Model '{}' has type Tensor but no model config is provided", - request_model_name - )))? } else if card.model_type.supports_completions() { return Ok(Response::new(ModelMetadataResponse { name: card.display_name, @@ -543,56 +560,50 @@ impl GrpcInferenceService for KserveService { .find(|card| request_model_name == &card.display_name) { if card.model_type.supports_tensor() { - if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref() - { - if let Some(triton_model_config) = - tensor_model_config.triton_model_config.as_ref() - { - let model_config = ModelConfig::decode(triton_model_config.as_slice()) - .map_err(|e| { - Status::invalid_argument(format!( - "Failed to deserialize model config: {}", - e - )) - })?; + let config = Config::from_runtime_config(&card.runtime_config).map_err(|e| { + Status::invalid_argument(format!( + "Model '{}' has type Tensor but: {}", + request_model_name, e + )) + })?; + match config { + Config::Triton(model_config) => { return Ok(Response::new(ModelConfigResponse { config: Some(model_config), })); } - let model_config = ModelConfig { - name: tensor_model_config.name.clone(), - platform: "dynamo".to_string(), - backend: "dynamo".to_string(), - input: tensor_model_config - .inputs - .iter() - .map(|input| ModelInput { - name: input.name.clone(), - data_type: input.data_type.to_kserve(), - dims: input.shape.clone(), - ..Default::default() - }) - .collect(), - output: tensor_model_config - .outputs - .iter() - .map(|output| ModelOutput { - name: output.name.clone(), - data_type: output.data_type.to_kserve(), - dims: output.shape.clone(), - ..Default::default() - }) - .collect(), - ..Default::default() - }; - return Ok(Response::new(ModelConfigResponse { - config: Some(model_config.clone()), - })); + Config::Dynamo(tensor_model_config) => { + let model_config = ModelConfig { + name: tensor_model_config.name.clone(), + platform: "dynamo".to_string(), + backend: "dynamo".to_string(), + input: tensor_model_config + .inputs + .iter() + .map(|input| ModelInput { + name: input.name.clone(), + data_type: input.data_type.to_kserve(), + dims: input.shape.clone(), + ..Default::default() + }) + .collect(), + output: tensor_model_config + .outputs + .iter() + .map(|output| ModelOutput { + name: output.name.clone(), + data_type: output.data_type.to_kserve(), + dims: output.shape.clone(), + ..Default::default() + }) + .collect(), + ..Default::default() + }; + return Ok(Response::new(ModelConfigResponse { + config: Some(model_config.clone()), + })); + } } - Err(Status::invalid_argument(format!( - "Model '{}' has type Tensor but no model config is provided", - request_model_name - )))? 
} else if card.model_type.supports_completions() { let config = ModelConfig { name: card.display_name, diff --git a/lib/llm/src/protocols/tensor.rs b/lib/llm/src/protocols/tensor.rs index 16f9f85b0c..9bc0668895 100644 --- a/lib/llm/src/protocols/tensor.rs +++ b/lib/llm/src/protocols/tensor.rs @@ -124,7 +124,7 @@ pub struct TensorMetadata { pub parameters: Parameters, } -#[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq, Default)] pub struct TensorModelConfig { pub name: String, pub inputs: Vec, @@ -135,17 +135,6 @@ pub struct TensorModelConfig { pub triton_model_config: Option>, } -impl Default for TensorModelConfig { - fn default() -> Self { - Self { - name: "".to_string(), - inputs: vec![], - outputs: vec![], - triton_model_config: None, - } - } -} - #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Tensor { pub metadata: TensorMetadata, diff --git a/lib/llm/tests/kserve_service.rs b/lib/llm/tests/kserve_service.rs index afb3301360..e002eb37e7 100644 --- a/lib/llm/tests/kserve_service.rs +++ b/lib/llm/tests/kserve_service.rs @@ -1350,9 +1350,8 @@ pub mod kserve_test { err ); assert!( - err.message() - .contains("has type Tensor but no model config is provided"), - "Expected error message to contain 'has type Tensor but no model config is provided', got: {}", + err.message().contains("no model config is provided"), + "Expected error message to contain 'no model config is provided', got: {}", err.message() ); @@ -1371,9 +1370,8 @@ pub mod kserve_test { err ); assert!( - err.message() - .contains("has type Tensor but no model config is provided"), - "Expected error message to contain 'has type Tensor but no model config is provided', got: {}", + err.message().contains("no model config is provided"), + "Expected error message to contain 'no model config is provided', got: {}", err.message() ); From 06677f420de2e51ae209da698ab01fd39fc91f7a Mon Sep 17 00:00:00 2001 From: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Date: Wed, 29 Oct 2025 17:46:30 +0800 Subject: [PATCH 8/8] chore: address comment Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> --- lib/llm/tests/kserve_service.rs | 97 ++++++++++++++++++++++++++++++++- 1 file changed, 96 insertions(+), 1 deletion(-) diff --git a/lib/llm/tests/kserve_service.rs b/lib/llm/tests/kserve_service.rs index e002eb37e7..d968365b7d 100644 --- a/lib/llm/tests/kserve_service.rs +++ b/lib/llm/tests/kserve_service.rs @@ -1264,7 +1264,7 @@ pub mod kserve_test { card.model_input = ModelInput::Tensor; card.runtime_config = ModelRuntimeConfig { tensor_model_config: Some(tensor::TensorModelConfig { - triton_model_config: Some(buf), + triton_model_config: Some(buf.clone()), ..Default::default() }), ..Default::default() @@ -1299,6 +1299,101 @@ pub mod kserve_test { config, expected_model_config, "Expected same model config to be returned", ); + + // Pass config with both TensorModelConfig and triton_model_config, + // check if the Triton model config is used. 
+ let _ = service_with_engines + .0 + .model_manager() + .remove_model_card("key"); + let mut card = ModelDeploymentCard::with_name_only(model_name); + card.model_type = ModelType::TensorBased; + card.model_input = ModelInput::Tensor; + let mut card = ModelDeploymentCard::with_name_only("tensor"); + card.model_type = ModelType::TensorBased; + card.model_input = ModelInput::Tensor; + card.runtime_config = ModelRuntimeConfig { + tensor_model_config: Some(tensor::TensorModelConfig { + name: "tensor".to_string(), + inputs: vec![tensor::TensorMetadata { + name: "input".to_string(), + data_type: tensor::DataType::Int32, + shape: vec![1], + parameters: Default::default(), + }], + outputs: vec![tensor::TensorMetadata { + name: "output".to_string(), + data_type: tensor::DataType::Bool, + shape: vec![-1], + parameters: Default::default(), + }], + triton_model_config: Some(buf.clone()), + }), + ..Default::default() + }; + let _ = service_with_engines + .0 + .model_manager() + .save_model_card("key", card); + let request = tonic::Request::new(ModelConfigRequest { + name: model_name.into(), + version: "".into(), + }); + + let response = client + .model_config(request) + .await + .unwrap() + .into_inner() + .config; + let Some(config) = response else { + panic!("Expected Some(config), got None"); + }; + assert_eq!( + config, expected_model_config, + "Expected same model config to be returned", + ); + + // Test invalid triton model config + let _ = service_with_engines + .0 + .model_manager() + .remove_model_card("key"); + let mut card = ModelDeploymentCard::with_name_only(model_name); + card.model_type = ModelType::TensorBased; + card.model_input = ModelInput::Tensor; + card.runtime_config = ModelRuntimeConfig { + tensor_model_config: Some(tensor::TensorModelConfig { + triton_model_config: Some(vec![1, 2, 3, 4, 5]), + ..Default::default() + }), + ..Default::default() + }; + let _ = service_with_engines + .0 + .model_manager() + .save_model_card("key", card); + + // success config + let request = tonic::Request::new(ModelConfigRequest { + name: model_name.into(), + version: "".into(), + }); + + let response = client.model_config(request).await; + assert!(response.is_err()); + let err = response.unwrap_err(); + assert_eq!( + err.code(), + tonic::Code::InvalidArgument, + "Expected InvalidArgument error, get {}", + err + ); + assert!( + err.message().contains("failed to decode Protobuf message"), + "Expected error message to contain 'failed to decode Protobuf message', got: {}", + err.message() + ); } #[rstest]
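Taken together, the end state of this series is that TensorModelConfig.triton_model_config carries the raw prost-encoded ModelConfig bytes (no base64 wrapper), and whenever the field is present the frontend ignores the basic inputs/outputs description in favor of the decoded Triton config. A rough sketch of the registration side in Rust, mirroring what the test above does (the surrounding card and registration plumbing is assumed, not shown):

use prost::Message;

fn runtime_config_with_triton(config: &inference::ModelConfig) -> tensor::TensorModelConfig {
    // Serialize the full Triton config; the frontend decodes it with
    // ModelConfig::decode and, because the field is set, skips the basic
    // inputs/outputs listed in TensorModelConfig.
    let mut buf = Vec::new();
    config
        .encode(&mut buf)
        .expect("encoding into a Vec never runs out of capacity");
    tensor::TensorModelConfig {
        triton_model_config: Some(buf),
        ..Default::default()
    }
}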