1 change: 1 addition & 0 deletions Cargo.lock


1 change: 1 addition & 0 deletions lib/bindings/python/Cargo.lock


1 change: 1 addition & 0 deletions lib/llm/Cargo.toml
@@ -116,6 +116,7 @@ utoipa-swagger-ui = { version = "9.0", features = ["axum"] }
tonic = { version = "0.13.1" }
# Request prost specifically so tonic-build properly compiles protobuf message
prost = { version = "0.13.5" }
base64 = "0.22.1"

# tokenizers
tokenizers = { version = "0.21.4", default-features = false, features = [
4 changes: 3 additions & 1 deletion lib/llm/build.rs
@@ -44,7 +44,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}

fn build_protos() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::compile_protos("src/grpc/protos/kserve.proto")?;
tonic_build::configure()
.type_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]")
.compile_protos(&["kserve.proto"], &["src/grpc/protos"])?;
Ok(())
}

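The `type_attribute(".", ...)` call attaches the serde derives to every type generated from kserve.proto. A minimal sketch of what that enables (assuming `serde_json` is available as a dependency and the generated `inference` module is in scope; the function is illustrative only, not part of the change):

use prost::Message;

fn roundtrip_example(config: &inference::ModelConfig) {
    // With the added derives, generated types can be serialized via serde-based formats...
    let as_json = serde_json::to_string(config).expect("serde serialization");
    let from_json: inference::ModelConfig =
        serde_json::from_str(&as_json).expect("serde deserialization");

    // ...while the usual prost/protobuf encoding keeps working unchanged.
    let mut buf = Vec::new();
    config.encode(&mut buf).expect("encoding into a Vec cannot fail");
    let from_protobuf =
        inference::ModelConfig::decode(buf.as_slice()).expect("protobuf decode");

    assert_eq!(from_json, from_protobuf);
}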
35 changes: 35 additions & 0 deletions lib/llm/src/grpc/service/kserve.rs
@@ -39,6 +39,9 @@ use inference::{
ModelMetadataRequest, ModelMetadataResponse, ModelStreamInferResponse,
};

use base64::engine::{Engine, general_purpose};
use prost::Message;

/// [gluo TODO] 'metrics' are for HTTP service and there is HTTP endpoint
/// for it as part of HTTP service. Should we always start HTTP service up
/// for non-inference?
@@ -418,6 +421,31 @@ impl GrpcInferenceService for KserveService {
if card.model_type.supports_tensor() {
if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref()
{
if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() {
let bytes = general_purpose::STANDARD
    .decode(triton_model_config)
    .map_err(|e| Status::invalid_argument(format!("Failed to decode base64 model config: {}", e)))?;
let model_config = ModelConfig::decode(&*bytes)
    .map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))?;
return Ok(Response::new(ModelMetadataResponse {
name: model_config.name,
versions: vec!["1".to_string()],
platform: model_config.platform,
inputs: model_config.input.iter().map(|input| inference::model_metadata_response::TensorMetadata {
name: input.name.clone(),
datatype: match inference::DataType::try_from(input.data_type) {
Ok(dt) => dt.as_str_name().to_string(),
Err(_) => "TYPE_INVALID".to_string(),
},
shape: input.dims.clone(),
}).collect(),
outputs: model_config.output.iter().map(|output| inference::model_metadata_response::TensorMetadata {
name: output.name.clone(),
datatype: match inference::DataType::try_from(output.data_type) {
Ok(dt) => dt.as_str_name().to_string(),
Err(_) => "TYPE_INVALID".to_string(),
},
shape: output.dims.clone(),
}).collect(),
}));
}
return Ok(Response::new(ModelMetadataResponse {
name: tensor_model_config.name.clone(),
versions: vec!["1".to_string()],
@@ -499,6 +527,13 @@ impl GrpcInferenceService for KserveService {
if card.model_type.supports_tensor() {
if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref()
{
if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() {
let bytes = general_purpose::STANDARD
    .decode(triton_model_config)
    .map_err(|e| Status::invalid_argument(format!("Failed to decode base64 model config: {}", e)))?;
let model_config = ModelConfig::decode(&*bytes)
    .map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))?;
return Ok(Response::new(ModelConfigResponse {
config: Some(model_config),
}));
}
let model_config = ModelConfig {
name: tensor_model_config.name.clone(),
platform: "dynamo".to_string(),
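The decode-and-deserialize step is identical in both handlers above, so it could be factored into a small shared helper. A sketch under that assumption (the helper name is hypothetical and not part of the change):

// Hypothetical helper shared by the model_metadata and model_config handlers.
fn decode_triton_model_config(encoded: &str) -> Result<ModelConfig, Status> {
    let bytes = general_purpose::STANDARD
        .decode(encoded)
        .map_err(|e| Status::invalid_argument(format!("Failed to decode base64 model config: {}", e)))?;
    ModelConfig::decode(bytes.as_slice())
        .map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))
}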
16 changes: 16 additions & 0 deletions lib/llm/src/protocols/tensor.rs
@@ -124,6 +124,22 @@ pub struct TensorModelConfig {
pub name: String,
pub inputs: Vec<TensorMetadata>,
pub outputs: Vec<TensorMetadata>,
// Optional Triton model config as a serialized protobuf string;
// if provided, it supersedes the basic model config defined above.
// The string is base64-encoded to ensure safe transport over text-based protocols.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub triton_model_config: Option<String>,
}

impl Default for TensorModelConfig {
fn default() -> Self {
Self {
name: "".to_string(),
inputs: vec![],
outputs: vec![],
triton_model_config: None,
}
}
}

#[derive(Serialize, Deserialize, Debug, Clone)]
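On the producer side of this convention, a worker serializes the Triton config with prost, base64-encodes it, and can leave the basic fields at their defaults. A minimal sketch (the helper name is hypothetical, and it assumes the prost-generated `inference::ModelConfig` type is in scope alongside `TensorModelConfig`):

use base64::engine::{Engine, general_purpose};
use prost::Message;

// Hypothetical helper: build a TensorModelConfig carrying a full Triton model config.
fn with_triton_model_config(config: &inference::ModelConfig) -> TensorModelConfig {
    let mut buf = Vec::new();
    config.encode(&mut buf).expect("encoding into a Vec cannot fail");
    TensorModelConfig {
        // name/inputs/outputs can stay at their defaults;
        // the serialized config supersedes them on the frontend.
        triton_model_config: Some(general_purpose::STANDARD.encode(&buf)),
        ..Default::default()
    }
}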
98 changes: 98 additions & 0 deletions lib/llm/tests/kserve_service.rs
@@ -41,6 +41,8 @@ pub mod kserve_test {
use tonic::{Request, Response, transport::Channel};

use dynamo_async_openai::types::Prompt;
use prost::Message;
use base64::engine::{Engine, general_purpose};

struct SplitEngine {}

@@ -359,6 +361,7 @@
ModelInfo = 8994,
TensorModel = 8995,
TensorModelTypes = 8996,
TritonModelConfig = 8997,
}

#[rstest]
@@ -1169,6 +1172,7 @@
data_type: tensor::DataType::Bool,
shape: vec![-1],
}],
triton_model_config: None,
}),
..Default::default()
};
@@ -1202,6 +1206,99 @@
);
}

#[rstest]
#[tokio::test]
async fn test_triton_model_config(
#[with(TestPort::TritonModelConfig as u16)] service_with_engines: (
KserveService,
Arc<SplitEngine>,
Arc<AlwaysFailEngine>,
Arc<LongRunningEngine>,
),
) {
// start server
let _running = RunningService::spawn(service_with_engines.0.clone());

let mut client = get_ready_client(TestPort::TritonModelConfig as u16, 5).await;

let model_name = "tensor";
let expected_model_config = inference::ModelConfig {
name: model_name.to_string(),
platform: "custom".to_string(),
backend: "custom".to_string(),
input: vec![inference::ModelInput {
name: "input".to_string(),
data_type: DataType::TypeInt32 as i32,
dims: vec![1],
optional: false,
..Default::default()
},
inference::ModelInput {
name: "optional_input".to_string(),
data_type: DataType::TypeInt32 as i32,
dims: vec![1],
optional: true,
..Default::default()
}],
output: vec![inference::ModelOutput {
name: "output".to_string(),
data_type: DataType::TypeBool as i32,
dims: vec![-1],
..Default::default()
}],
model_transaction_policy: Some(inference::ModelTransactionPolicy {
decoupled: true,
}),
..Default::default()
};

let mut buf = vec![];
expected_model_config.encode(&mut buf).unwrap();
let buf = general_purpose::STANDARD.encode(&buf);

// Register a tensor model
let mut card = ModelDeploymentCard::with_name_only(model_name);
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
card.runtime_config = ModelRuntimeConfig {
tensor_model_config: Some(tensor::TensorModelConfig {
triton_model_config: Some(buf),
..Default::default()
}),
..Default::default()
};
let tensor = Arc::new(TensorEngine {});
service_with_engines
.0
.model_manager()
.add_tensor_model("tensor", card.mdcsum(), tensor.clone())
.unwrap();
let _ = service_with_engines
.0
.model_manager()
.save_model_card("key", card);

// Request the model config; the full Triton config should be returned
let request = tonic::Request::new(ModelConfigRequest {
name: model_name.into(),
version: "".into(),
});

let response = client
.model_config(request)
.await
.unwrap()
.into_inner()
.config;
let Some(config) = response else {
panic!("Expected Some(config), got None");
};
assert_eq!(
config, expected_model_config,
"Expected same model config to be returned",
);
}

#[rstest]
#[tokio::test]
async fn test_tensor_infer(
@@ -1299,6 +1396,7 @@
data_type: tensor::DataType::Bool,
shape: vec![-1],
}],
triton_model_config: None,
}),
..Default::default()
};
39 changes: 34 additions & 5 deletions tests/frontend/grpc/echo_tensor_worker.py
@@ -4,8 +4,14 @@
# Usage: `TEST_END_TO_END=1 python test_tensor.py` to run this worker as tensor based echo worker.


from base64 import b64encode

import uvloop

# The test runs in an environment that has tritonclient installed,
# which contains the generated module equivalent to model_config.proto.
from tritonclient.grpc.model_config_pb2 import ModelConfig

from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker

@@ -17,12 +23,32 @@ async def echo_tensor_worker(runtime: DistributedRuntime):

endpoint = component.endpoint("generate")

triton_model_config = ModelConfig()
triton_model_config.name = "echo"
triton_model_config.platform = "custom"
input_tensor = triton_model_config.input.add()
input_tensor.name = "input"
input_tensor.data_type = "TYPE_STRING"
input_tensor.dims.extend([-1])
optional_input_tensor = triton_model_config.input.add()
optional_input_tensor.name = "optional_input"
optional_input_tensor.data_type = "TYPE_INT32"
optional_input_tensor.dims.extend([-1])
optional_input_tensor.optional = True
output_tensor = triton_model_config.output.add()
output_tensor.name = "dummy_output"
output_tensor.data_type = "TYPE_STRING"
output_tensor.dims.extend([-1])
triton_model_config.model_transaction_policy.decoupled = True

# Serialize and base64-encode the Triton model config
b64_config = b64encode(triton_model_config.SerializeToString())

model_config = {
"name": "echo",
"inputs": [
{"name": "dummy_input", "data_type": "Bytes", "shape": [-1]},
],
"outputs": [{"name": "dummy_output", "data_type": "Bytes", "shape": [-1]}],
"name": "",
"inputs": [],
"outputs": [],
"triton_model_config": b64_config,
}
runtime_config = ModelRuntimeConfig()
runtime_config.set_tensor_model_config(model_config)
@@ -45,6 +71,9 @@ async def echo_tensor_worker(runtime: DistributedRuntime):


async def generate(request, context):
# [NOTE] gluo: currently there is no frontend-side validation
# of the request against the model config,
# so any request will reach here and be echoed back.
print(f"Echoing request: {request}")
yield {"model": request["model"], "tensors": request["tensors"]}

1 change: 1 addition & 0 deletions tests/frontend/grpc/test_tensor_mocker_engine.py
@@ -120,3 +120,4 @@ def start_services(request, runtime_services):
@pytest.mark.model(TEST_MODEL)
def test_echo() -> None:
triton_echo_client.run_infer()
triton_echo_client.get_config()
14 changes: 14 additions & 0 deletions tests/frontend/grpc/triton_echo_client.py
@@ -43,3 +43,17 @@ def run_infer():

assert np.array_equal(input0_data, output0_data)
assert np.array_equal(input1_data, output1_data)


def get_config():
server_url = "localhost:8000"
try:
triton_client = grpcclient.InferenceServerClient(url=server_url)
except Exception as e:
print("channel creation failed: " + str(e))
sys.exit()

model_name = "echo"
response = triton_client.get_model_config(model_name=model_name)
# Check one of the fields that can only be set by providing a Triton model config
assert response.config.model_transaction_policy.decoupled