lib/llm/build.rs (3 additions, 1 deletion)
@@ -44,7 +44,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}

fn build_protos() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::compile_protos("src/grpc/protos/kserve.proto")?;
tonic_build::configure()
.type_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]")
.compile_protos(&["kserve.proto"], &["src/grpc/protos"])?;
Ok(())
}

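The `.type_attribute(".", ...)` call above attaches the serde derives to every message generated from kserve.proto, so the prost types can be serialized with serde as well as with the protobuf wire format. A minimal sketch of what this enables, assuming a serde_json dependency and the `inference` module path used in kserve.rs:

use prost::Message;

fn serde_round_trip() -> Result<(), Box<dyn std::error::Error>> {
    // A prost-generated message; fields not set here come from Default.
    let config = inference::ModelConfig {
        name: "echo".to_string(),
        platform: "custom".to_string(),
        ..Default::default()
    };
    // JSON via the serde derives added in build.rs ...
    let json = serde_json::to_string(&config)?;
    let from_json: inference::ModelConfig = serde_json::from_str(&json)?;
    // ... and the usual protobuf bytes via prost.
    let from_bytes = inference::ModelConfig::decode(config.encode_to_vec().as_slice())?;
    assert_eq!(from_json, from_bytes);
    Ok(())
}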
lib/llm/src/grpc/service/kserve.rs (60 additions, 0 deletions)
@@ -39,6 +39,8 @@ use inference::{
ModelMetadataRequest, ModelMetadataResponse, ModelStreamInferResponse,
};

use prost::Message;

/// [gluo TODO] 'metrics' are served by the HTTP service, which exposes an HTTP
/// endpoint for them. Should we always start the HTTP service for
/// non-inference traffic?
@@ -418,6 +420,50 @@ impl GrpcInferenceService for KserveService {
if card.model_type.supports_tensor() {
if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref()
{
if let Some(triton_model_config) =
tensor_model_config.triton_model_config.as_ref()
{
let model_config = ModelConfig::decode(triton_model_config.as_slice())
.map_err(|e| {
Status::invalid_argument(format!(
"Failed to deserialize model config: {}",
e
))
})?;
return Ok(Response::new(ModelMetadataResponse {
name: model_config.name,
versions: vec!["1".to_string()],
platform: model_config.platform,
inputs: model_config
.input
.iter()
.map(|input| inference::model_metadata_response::TensorMetadata {
name: input.name.clone(),
datatype: match inference::DataType::try_from(input.data_type) {
Ok(dt) => dt.as_str_name().to_string(),
Err(_) => "TYPE_INVALID".to_string(),
},
shape: input.dims.clone(),
})
.collect(),
outputs: model_config
.output
.iter()
.map(
|output| inference::model_metadata_response::TensorMetadata {
name: output.name.clone(),
datatype: match inference::DataType::try_from(
output.data_type,
) {
Ok(dt) => dt.as_str_name().to_string(),
Err(_) => "TYPE_INVALID".to_string(),
},
shape: output.dims.clone(),
},
)
.collect(),
}));
}
return Ok(Response::new(ModelMetadataResponse {
name: tensor_model_config.name.clone(),
versions: vec!["1".to_string()],
@@ -499,6 +545,20 @@ impl GrpcInferenceService for KserveService {
if card.model_type.supports_tensor() {
if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref()
{
if let Some(triton_model_config) =
tensor_model_config.triton_model_config.as_ref()
{
let model_config = ModelConfig::decode(triton_model_config.as_slice())
.map_err(|e| {
Status::invalid_argument(format!(
"Failed to deserialize model config: {}",
e
))
})?;
return Ok(Response::new(ModelConfigResponse {
config: Some(model_config),
}));
}
let model_config = ModelConfig {
name: tensor_model_config.name.clone(),
platform: "dynamo".to_string(),
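The metadata branch above maps the raw `data_type` integer to its datatype name twice, once for inputs and once for outputs. A small helper along these lines (a sketch, not part of this diff) could factor out the repetition:

// Hypothetical helper: map a raw protobuf enum value to its datatype name,
// falling back to TYPE_INVALID for values outside the generated enum.
fn datatype_name(raw: i32) -> String {
    match inference::DataType::try_from(raw) {
        Ok(dt) => dt.as_str_name().to_string(),
        Err(_) => "TYPE_INVALID".to_string(),
    }
}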
lib/llm/src/protocols/tensor.rs (15 additions, 0 deletions)
@@ -129,6 +129,21 @@ pub struct TensorModelConfig {
pub name: String,
pub inputs: Vec<TensorMetadata>,
pub outputs: Vec<TensorMetadata>,
// Optional Triton model config as serialized protobuf bytes;
// if provided, it supersedes the basic model config defined above.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub triton_model_config: Option<Vec<u8>>,
}

impl Default for TensorModelConfig {
fn default() -> Self {
Self {
name: "".to_string(),
inputs: vec![],
outputs: vec![],
triton_model_config: None,
}
}
}

#[derive(Serialize, Deserialize, Debug, Clone)]
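The new field holds the serialized `inference::ModelConfig` message itself. A rough sketch of how a Rust-side producer might populate it, assuming access to the prost-generated `inference` types (the test below does the equivalent with prost, and the Python worker with tritonclient):

use prost::Message;

// Sketch only: encode a Triton model config and embed the bytes so they
// supersede the basic name/inputs/outputs fields of TensorModelConfig.
fn tensor_config_with_triton_override() -> TensorModelConfig {
    let triton_config = inference::ModelConfig {
        name: "echo".to_string(),
        platform: "custom".to_string(),
        ..Default::default()
    };
    TensorModelConfig {
        name: "echo".to_string(),
        triton_model_config: Some(triton_config.encode_to_vec()),
        ..Default::default()
    }
}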
lib/llm/tests/kserve_service.rs (96 additions, 0 deletions)
@@ -42,6 +42,7 @@ pub mod kserve_test {
use tonic::{Request, Response, transport::Channel};

use dynamo_async_openai::types::Prompt;
use prost::Message;

struct SplitEngine {}

@@ -361,6 +362,7 @@
ModelInfo = 8994,
TensorModel = 8995,
TensorModelTypes = 8996,
TritonModelConfig = 8997,
}

#[rstest]
@@ -1173,6 +1175,7 @@
shape: vec![-1],
parameters: Default::default(),
}],
triton_model_config: None,
}),
..Default::default()
};
@@ -1206,6 +1209,98 @@
);
}

#[rstest]
#[tokio::test]
async fn test_triton_model_config(
#[with(TestPort::TritonModelConfig as u16)] service_with_engines: (
KserveService,
Arc<SplitEngine>,
Arc<AlwaysFailEngine>,
Arc<LongRunningEngine>,
),
) {
// start server
let _running = RunningService::spawn(service_with_engines.0.clone());

let mut client = get_ready_client(TestPort::TritonModelConfig as u16, 5).await;

let model_name = "tensor";
let expected_model_config = inference::ModelConfig {
name: model_name.to_string(),
platform: "custom".to_string(),
backend: "custom".to_string(),
input: vec![
inference::ModelInput {
name: "input".to_string(),
data_type: DataType::TypeInt32 as i32,
dims: vec![1],
optional: false,
..Default::default()
},
inference::ModelInput {
name: "optional_input".to_string(),
data_type: DataType::TypeInt32 as i32,
dims: vec![1],
optional: true,
..Default::default()
},
],
output: vec![inference::ModelOutput {
name: "output".to_string(),
data_type: DataType::TypeBool as i32,
dims: vec![-1],
..Default::default()
}],
model_transaction_policy: Some(inference::ModelTransactionPolicy { decoupled: true }),
..Default::default()
};

let mut buf = vec![];
expected_model_config.encode(&mut buf).unwrap();

// Register a tensor model
let mut card = ModelDeploymentCard::with_name_only(model_name);
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
card.runtime_config = ModelRuntimeConfig {
tensor_model_config: Some(tensor::TensorModelConfig {
triton_model_config: Some(buf),
..Default::default()
}),
..Default::default()
};
let tensor = Arc::new(TensorEngine {});
service_with_engines
.0
.model_manager()
.add_tensor_model("tensor", card.mdcsum(), tensor.clone())
.unwrap();
let _ = service_with_engines
.0
.model_manager()
.save_model_card("key", card);

// success config
let request = tonic::Request::new(ModelConfigRequest {
name: model_name.into(),
version: "".into(),
});

let response = client
.model_config(request)
.await
.unwrap()
.into_inner()
.config;
let Some(config) = response else {
panic!("Expected Some(config), got None");
};
assert_eq!(
config, expected_model_config,
"Expected same model config to be returned",
);
}

#[rstest]
#[tokio::test]
async fn test_tensor_infer(
@@ -1305,6 +1400,7 @@
shape: vec![-1],
parameters: Default::default(),
}],
triton_model_config: None,
}),
..Default::default()
};
tests/frontend/grpc/echo_tensor_worker.py (34 additions, 6 deletions)
@@ -4,6 +4,9 @@
# Usage: `TEST_END_TO_END=1 python test_tensor.py` to run this worker as a tensor-based echo worker.


# The test is expected to run in an environment that has tritonclient installed,
# which contains the generated bindings equivalent to model_config.proto.
import tritonclient.grpc.model_config_pb2 as mc
import uvloop

from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
@@ -17,17 +20,39 @@ async def echo_tensor_worker(runtime: DistributedRuntime):

endpoint = component.endpoint("generate")

triton_model_config = mc.ModelConfig()
triton_model_config.name = "echo"
triton_model_config.platform = "custom"
input_tensor = triton_model_config.input.add()
input_tensor.name = "input"
input_tensor.data_type = mc.TYPE_STRING
input_tensor.dims.extend([-1])
optional_input_tensor = triton_model_config.input.add()
optional_input_tensor.name = "optional_input"
optional_input_tensor.data_type = mc.TYPE_INT32
optional_input_tensor.dims.extend([-1])
optional_input_tensor.optional = True
output_tensor = triton_model_config.output.add()
output_tensor.name = "dummy_output"
output_tensor.data_type = mc.TYPE_STRING
output_tensor.dims.extend([-1])
triton_model_config.model_transaction_policy.decoupled = True

model_config = {
"name": "echo",
"inputs": [
{"name": "dummy_input", "data_type": "Bytes", "shape": [-1]},
],
"outputs": [{"name": "dummy_output", "data_type": "Bytes", "shape": [-1]}],
"name": "",
"inputs": [],
"outputs": [],
"triton_model_config": triton_model_config.SerializeToString(),
}
runtime_config = ModelRuntimeConfig()
runtime_config.set_tensor_model_config(model_config)

assert model_config == runtime_config.get_tensor_model_config()
# Internally the bytes value is converted to a list of ints, so convert it back before comparing.
retrieved_model_config = runtime_config.get_tensor_model_config()
retrieved_model_config["triton_model_config"] = bytes(
retrieved_model_config["triton_model_config"]
)
assert model_config == retrieved_model_config

# [gluo FIXME] register_llm will attempt to load a LLM model,
# which is not well-defined for Tensor yet. Currently provide
Expand All @@ -46,6 +71,9 @@ async def echo_tensor_worker(runtime: DistributedRuntime):

async def generate(request, context):
"""Echo tensors and parameters back to the client."""
# [NOTE] gluo: currently the frontend does not validate the actual request
# against the model config, so any request will reach here and be
# echoed back.
print(f"Echoing request: {request}")

params = {}
tests/frontend/grpc/test_tensor_mocker_engine.py (1 addition, 0 deletions)
@@ -120,3 +120,4 @@ def start_services(request, runtime_services):
@pytest.mark.model(TEST_MODEL)
def test_echo() -> None:
triton_echo_client.run_infer()
triton_echo_client.get_config()
tests/frontend/grpc/triton_echo_client.py (14 additions, 0 deletions)
@@ -43,3 +43,17 @@ def run_infer():

assert np.array_equal(input0_data, output0_data)
assert np.array_equal(input1_data, output1_data)


def get_config():
server_url = "localhost:8000"
try:
triton_client = grpcclient.InferenceServerClient(url=server_url)
except Exception as e:
print("channel creation failed: " + str(e))
sys.exit()

model_name = "echo"
response = triton_client.get_model_config(model_name=model_name)
# Check one of the fields that can only be set by providing a Triton model config
assert response.config.model_transaction_policy.decoupled