diff --git a/examples/log_parsing/README.md b/examples/log_parsing/README.md index 55dbffa0bb..b4f882d763 100644 --- a/examples/log_parsing/README.md +++ b/examples/log_parsing/README.md @@ -110,7 +110,7 @@ PYTHONPATH="examples/log_parsing" \ morpheus --log_level INFO \ --plugin "inference" \ --plugin "postprocessing" \ - run --num_threads 1 --use_cpp False --pipeline_batch_size 1024 --model_max_batch_size 32 \ + run --num_threads 1 --pipeline_batch_size 1024 --model_max_batch_size 32 \ pipeline-nlp \ from-file --filename ./models/datasets/validation-data/log-parsing-validation-data-input.csv \ deserialize \ diff --git a/examples/log_parsing/inference.py b/examples/log_parsing/inference.py index f298cbce64..7e6234de82 100644 --- a/examples/log_parsing/inference.py +++ b/examples/log_parsing/inference.py @@ -13,33 +13,28 @@ # limitations under the License. import logging -import typing -from functools import partial import cupy as cp -import mrc import numpy as np import tritonclient.grpc as tritonclient -from mrc.core import operators as ops from scipy.special import softmax -from messages import MultiPostprocLogParsingMessage # pylint: disable=no-name-in-module -from messages import PostprocMemoryLogParsing # pylint: disable=no-name-in-module -from messages import ResponseMemoryLogParsing # pylint: disable=no-name-in-module from morpheus.cli.register_stage import register_stage from morpheus.config import Config from morpheus.config import PipelineModes from morpheus.messages import MultiInferenceMessage +from morpheus.messages import MultiInferenceNLPMessage +from morpheus.messages import MultiResponseMessage +from morpheus.messages import TensorMemory from morpheus.pipeline.stage_schema import StageSchema -from morpheus.stages.inference.inference_stage import InferenceStage -from morpheus.stages.inference.inference_stage import InferenceWorker -from morpheus.stages.inference.triton_inference_stage import _TritonInferenceWorker +from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage +from morpheus.stages.inference.triton_inference_stage import TritonInferenceWorker from morpheus.utils.producer_consumer_queue import ProducerConsumerQueue logger = logging.getLogger(__name__) -class TritonInferenceLogParsing(_TritonInferenceWorker): +class TritonInferenceLogParsing(TritonInferenceWorker): """ This class extends TritonInference to deal with scenario-specific NLP models inference requests like building response. @@ -59,78 +54,47 @@ class TritonInferenceLogParsing(_TritonInferenceWorker): use_shared_memory: bool, default = True Whether or not to use CUDA Shared IPC Memory for transferring data to Triton. Using CUDA IPC reduces network transfer time but requires that Morpheus and Triton are located on the same machine - + needs_logits : bool, default = True + Determines whether a logits calculation is needed for the value returned by the Triton inference response. """ - def __init__(self, - inf_queue: ProducerConsumerQueue, - c: Config, - model_name: str, - server_url: str, - force_convert_inputs: bool, - use_shared_memory: bool, - inout_mapping: typing.Dict[str, str] = None): - # Some models use different names for the same thing. 
Set that here but allow user customization - default_mapping = { - "attention_mask": "input_mask", - } - - default_mapping.update(inout_mapping if inout_mapping is not None else {}) - - super().__init__(inf_queue, - c, - model_name=model_name, - server_url=server_url, - force_convert_inputs=force_convert_inputs, - use_shared_memory=use_shared_memory, - inout_mapping=default_mapping) - - @classmethod - def needs_logits(cls): - return True - - @classmethod - def default_inout_mapping(cls) -> typing.Dict[str, str]: - # Some models use different names for the same thing. Set that here but allow user customization - return {"attention_mask": "input_mask"} + def build_output_message(self, x: MultiInferenceMessage) -> MultiResponseMessage: + seq_ids = cp.zeros((x.count, 3), dtype=cp.uint32) + seq_ids[:, 0] = cp.arange(x.mess_offset, x.mess_offset + x.count, dtype=cp.uint32) + seq_ids[:, 2] = x.get_tensor('seq_ids')[:, 2] - def build_output_message(self, x: MultiInferenceMessage) -> MultiPostprocLogParsingMessage: - - memory = PostprocMemoryLogParsing( + memory = TensorMemory( count=x.count, - confidences=cp.zeros((x.count, self._inputs[list(self._inputs.keys())[0]].shape[1])), - labels=cp.zeros((x.count, self._inputs[list(self._inputs.keys())[0]].shape[1])), - input_ids=cp.zeros((x.count, x.input_ids.shape[1])), - seq_ids=cp.zeros((x.count, x.seq_ids.shape[1])), - ) - - output_message = MultiPostprocLogParsingMessage(meta=x.meta, - mess_offset=x.mess_offset, - mess_count=x.mess_count, - memory=memory, - offset=0, - count=x.count) - return output_message - - def _build_response(self, batch: MultiInferenceMessage, - result: tritonclient.InferResult) -> ResponseMemoryLogParsing: - - output = {output.mapped_name: result.as_numpy(output.name) for output in self._outputs.values()} - output = {key: softmax(val, axis=2) for key, val in output.items()} - confidences = {key: np.amax(val, axis=2) for key, val in output.items()} - labels = {key: np.argmax(val, axis=2) for key, val in output.items()} - - mem = ResponseMemoryLogParsing( - count=output[list(output.keys())[0]].shape[0], - confidences=cp.array(confidences[list(output.keys())[0]]), - labels=cp.array(labels[list(output.keys())[0]]), - ) - - return mem + tensors={ + 'confidences': cp.zeros((x.count, self._inputs[list(self._inputs.keys())[0]].shape[1])), + 'labels': cp.zeros((x.count, self._inputs[list(self._inputs.keys())[0]].shape[1])), + 'input_ids': cp.zeros((x.count, x.get_tensor('input_ids').shape[1])), + 'seq_ids': seq_ids + }) + + return MultiResponseMessage(meta=x.meta, + mess_offset=x.mess_offset, + mess_count=x.mess_count, + memory=memory, + offset=0, + count=x.count) + + def _build_response(self, batch: MultiInferenceMessage, result: tritonclient.InferResult) -> TensorMemory: + + outputs = {output.mapped_name: result.as_numpy(output.name) for output in self._outputs.values()} + outputs = {key: softmax(val, axis=2) for key, val in outputs.items()} + confidences = {key: np.amax(val, axis=2) for key, val in outputs.items()} + labels = {key: np.argmax(val, axis=2) for key, val in outputs.items()} + + return TensorMemory(count=outputs[list(outputs.keys())[0]].shape[0], + tensors={ + 'confidences': cp.array(confidences[list(outputs.keys())[0]]), + 'labels': cp.array(labels[list(outputs.keys())[0]]) + }) @register_stage("inf-logparsing", modes=[PipelineModes.NLP]) -class LogParsingInferenceStage(InferenceStage): +class LogParsingInferenceStage(TritonInferenceStage): """ NLP Triton inference stage for log parsing pipeline. 
@@ -149,7 +113,11 @@ class LogParsingInferenceStage(InferenceStage): use_shared_memory: bool, default = False, is_flag = True Whether or not to use CUDA Shared IPC Memory for transferring data to Triton. Using CUDA IPC reduces network transfer time but requires that Morpheus and Triton are located on the same machine - + needs_logits : bool, default = True, is_flag = True + Determines whether a logits calculation is needed for the value returned by the Triton inference response. + inout_mapping : dict[str, str], optional + Dictionary used to map pipeline input/output names to Triton input/output names. Use this if the + Morpheus names do not match the model. """ def __init__(self, @@ -157,109 +125,59 @@ def __init__(self, model_name: str, server_url: str, force_convert_inputs: bool = False, - use_shared_memory: bool = False): - super().__init__(c) - - self._config = c - - self._kwargs = { - "model_name": model_name, - "server_url": server_url, - "force_convert_inputs": force_convert_inputs, - "use_shared_memory": use_shared_memory, - } - - self._requires_seg_ids = False + use_shared_memory: bool = False, + needs_logits: bool = True, + inout_mapping: dict[str, str] = None): + super().__init__(c, + model_name=model_name, + server_url=server_url, + force_convert_inputs=force_convert_inputs, + use_shared_memory=use_shared_memory, + needs_logits=needs_logits, + inout_mapping=inout_mapping) - def supports_cpp_node(self): - # Get the value from the worker class + def supports_cpp_node(self) -> bool: return False def compute_schema(self, schema: StageSchema): - schema.output_schema.set_type(MultiPostprocLogParsingMessage) - - def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject: - - def py_inference_fn(obs: mrc.Observable, sub: mrc.Subscriber): - - worker = self._get_inference_worker(self._inf_queue) - - worker.init() - - outstanding_requests = 0 - - def on_next(x: MultiInferenceMessage): - nonlocal outstanding_requests - - batches = self._split_batches(x, self._max_batch_size) - - output_message = worker.build_output_message(x) - - memory = output_message.memory - - fut_list = [] - - for batch in batches: - outstanding_requests += 1 + schema.output_schema.set_type(MultiResponseMessage) - fut = mrc.Future() - - def set_output_fut(resp: ResponseMemoryLogParsing, inner_b, inner_f: mrc.Future): - nonlocal outstanding_requests - inner_memory = self._convert_one_response(memory, inner_b, resp) - - inner_f.set_result(inner_memory) - - outstanding_requests -= 1 - - fut_list.append(fut) - - worker.process(batch, partial(set_output_fut, inner_b=batch, inner_f=fut)) - - for f in fut_list: - f.result() - - return output_message - - obs.pipe(ops.map(on_next)).subscribe(sub) - - assert outstanding_requests == 0, "Not all inference requests were completed" + @staticmethod + def _convert_one_response(output: MultiResponseMessage, inf: MultiInferenceNLPMessage, + res: TensorMemory) -> MultiResponseMessage: + memory = output.memory - if (self._build_cpp_node()): - node = self._get_cpp_inference_node(builder) - else: - node = builder.make_node(self.unique_name, ops.build(py_inference_fn)) + out_seq_ids = memory.get_tensor('seq_ids') + input_ids = memory.get_tensor('input_ids') + confidences = memory.get_tensor('confidences') + labels = memory.get_tensor('labels') - # Set the concurrency level to be up with the thread count - node.launch_options.pe_count = self._thread_count - builder.make_edge(input_node, node) + seq_ids = inf.get_id_tensor() - return node + seq_offset = 
seq_ids[0, 0].item() - output.mess_offset + seq_count = (seq_ids[-1, 0].item() + 1 - seq_offset) - output.mess_offset - @staticmethod - def _convert_one_response(output: PostprocMemoryLogParsing, - inf: MultiInferenceMessage, - res: ResponseMemoryLogParsing): + input_ids[inf.offset:inf.count + inf.offset, :] = inf.get_tensor('input_ids') + out_seq_ids[inf.offset:inf.count + inf.offset, :] = seq_ids - output.input_ids[inf.offset:inf.count + inf.offset, :] = inf.input_ids - output.seq_ids[inf.offset:inf.count + inf.offset, :] = inf.seq_ids + resp_confidences = res.get_tensor('confidences') + resp_labels = res.get_tensor('labels') # Two scenarios: if (inf.mess_count == inf.count): - output.confidences[inf.offset:inf.count + inf.offset, :] = res.confidences - output.labels[inf.offset:inf.count + inf.offset, :] = res.labels + assert seq_count == res.count + confidences[inf.offset:inf.offset + inf.count, :] = resp_confidences + labels[inf.offset:inf.offset + inf.count, :] = resp_labels else: assert inf.count == res.count - mess_ids = inf.seq_ids[:, 0].get().tolist() + mess_ids = seq_ids[:, 0].get().tolist() - # Out message has more reponses, so we have to do key based blending of probs for i, idx in enumerate(mess_ids): - output.confidences[idx, :] = cp.maximum(output.confidences[idx, :], res.confidences[i, :]) - output.labels[idx, :] = cp.maximum(output.labels[idx, :], res.labels[i, :]) - - return MultiPostprocLogParsingMessage.from_message(inf, memory=output, offset=inf.offset, count=inf.mess_count) + confidences[idx, :] = cp.maximum(confidences[idx, :], resp_confidences[i, :]) + labels[idx, :] = cp.maximum(labels[idx, :], resp_labels[i, :]) - def _get_inference_worker(self, inf_queue: ProducerConsumerQueue) -> InferenceWorker: + return MultiResponseMessage.from_message(inf, memory=memory, offset=inf.offset, count=inf.mess_count) - return TritonInferenceLogParsing(inf_queue, self._config, **self._kwargs) + def _get_inference_worker(self, inf_queue: ProducerConsumerQueue) -> TritonInferenceLogParsing: + return TritonInferenceLogParsing(inf_queue=inf_queue, c=self._config, **self._kwargs) diff --git a/examples/log_parsing/messages.py b/examples/log_parsing/messages.py deleted file mode 100644 index f39c9093db..0000000000 --- a/examples/log_parsing/messages.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright (c) 2021-2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import dataclasses - -import cupy as cp - -from morpheus.messages import InferenceMemory -from morpheus.messages import MultiInferenceMessage -from morpheus.messages import MultiResponseMessage -from morpheus.messages import ResponseMemory -from morpheus.messages.data_class_prop import DataClassProp - - -@dataclasses.dataclass(init=False) -class ResponseMemoryLogParsing(ResponseMemory, cpp_class=None): - - confidences: dataclasses.InitVar[cp.ndarray] = DataClassProp(ResponseMemory._get_tensor_prop, - ResponseMemory.set_output) - labels: dataclasses.InitVar[cp.ndarray] = DataClassProp(ResponseMemory._get_tensor_prop, ResponseMemory.set_output) - - def __init__(self, *, count: int, confidences: cp.ndarray, labels: cp.ndarray): - super().__init__(count=count, tensors={'confidences': confidences, 'labels': labels}) - - -@dataclasses.dataclass -class MultiResponseLogParsingMessage(MultiResponseMessage, cpp_class=None): - """ - A stronger typed version of `MultiResponseMessage` that is used for inference workloads that return a probability - array. Helps ensure the proper outputs are set and eases debugging. - """ - - @property - def confidences(self): - """ - Returns token-ids for each string padded with 0s to max_length. - - Returns - ------- - cupy.ndarray - The token-ids for each string padded with 0s to max_length. - - """ - - return self.get_output("confidences") - - @property - def labels(self): - """ - Returns sequence ids, which are used to keep track of which inference requests belong to each message. - - Returns - ------- - cupy.ndarray - Ids used to index from an inference input to a message. Necessary since there can be more - inference inputs than messages (i.e. If some messages get broken into multiple inference requests) - - """ - - return self.get_output("labels") - - @property - def input_ids(self): - """ - input_ids - - Returns - ------- - cp.ndarray - input_ids - - """ - - return self.get_output("input_ids") - - @property - def seq_ids(self): - """ - seq_ids - - Returns - ------- - cp.ndarray - seq_ids - - """ - - return self.get_output("seq_ids") - - -@dataclasses.dataclass(init=False) -class PostprocMemoryLogParsing(InferenceMemory): - """ - This is a container class for data that needs to be submitted to the inference server for NLP category - usecases. - - Parameters - ---------- - confidences: cp.ndarray - confidences calculated from softmax - labels: cp.ndarray - index of highest confidence - input_ids : cp.ndarray - The token-ids for each string padded with 0s to max_length. - seq_ids : cp.ndarray - Ids used to index from an inference input to a message. Necessary since there can be more inference - inputs than messages (i.e. 
If some messages get broken into multiple inference requests) - - """ - - confidences: dataclasses.InitVar[cp.ndarray] = DataClassProp(InferenceMemory._get_tensor_prop, - InferenceMemory.set_input) - labels: dataclasses.InitVar[cp.ndarray] = DataClassProp(InferenceMemory._get_tensor_prop, InferenceMemory.set_input) - input_ids: dataclasses.InitVar[cp.ndarray] = DataClassProp(InferenceMemory._get_tensor_prop, - InferenceMemory.set_input) - seq_ids: dataclasses.InitVar[cp.ndarray] = DataClassProp(InferenceMemory._get_tensor_prop, - InferenceMemory.set_input) - - def __init__(self, - *, - count: int, - confidences: cp.ndarray, - labels: cp.ndarray, - input_ids: cp.ndarray, - seq_ids: cp.ndarray): - super().__init__(count=count, - tensors={ - 'confidences': confidences, 'labels': labels, 'input_ids': input_ids, 'seq_ids': seq_ids - }) - - -@dataclasses.dataclass -class MultiPostprocLogParsingMessage(MultiInferenceMessage): - """ - A stronger typed version of `MultiInferenceMessage` that is used for NLP workloads. Helps ensure the - proper inputs are set and eases debugging. - """ - - @property - def confidences(self): - """ - Returns token-ids for each string padded with 0s to max_length. - - Returns - ------- - cupy.ndarray - The token-ids for each string padded with 0s to max_length. - - """ - - return self.get_input("confidences") - - @property - def labels(self): - """ - Returns sequence ids, which are used to keep track of which inference requests belong to each message. - - Returns - ------- - cupy.ndarray - Ids used to index from an inference input to a message. Necessary since there can be more - inference inputs than messages (i.e. If some messages get broken into multiple inference requests) - - """ - - return self.get_input("labels") - - @property - def input_ids(self): - """ - input_ids - - Returns - ------- - cp.ndarray - input_ids - - """ - - return self.get_input("input_ids") - - @property - def seq_ids(self): - """ - seq_ids - - Returns - ------- - cp.ndarray - seq_ids - - """ - - return self.get_input("seq_ids") diff --git a/examples/log_parsing/postprocessing.py b/examples/log_parsing/postprocessing.py index c36d7c897e..a43419f63d 100644 --- a/examples/log_parsing/postprocessing.py +++ b/examples/log_parsing/postprocessing.py @@ -22,11 +22,13 @@ import pandas as pd from mrc.core import operators as ops -from messages import MultiPostprocLogParsingMessage # pylint: disable=no-name-in-module +import cudf + from morpheus.cli.register_stage import register_stage from morpheus.config import Config from morpheus.config import PipelineModes from morpheus.messages import MessageMeta +from morpheus.messages import MultiResponseMessage from morpheus.pipeline.single_port_stage import SinglePortStage from morpheus.pipeline.stage_schema import StageSchema @@ -73,18 +75,18 @@ def supports_cpp_node(self): return False def accepted_types(self) -> typing.Tuple: - return (MultiPostprocLogParsingMessage, ) + return (MultiResponseMessage, ) def compute_schema(self, schema: StageSchema): schema.output_schema.set_type(MessageMeta) - def _postprocess(self, x: MultiPostprocLogParsingMessage): + def _postprocess(self, x: MultiResponseMessage): - infer_pdf = pd.DataFrame(x.seq_ids.get()).astype(int) + infer_pdf = pd.DataFrame(x.get_tensor('seq_ids').get()).astype(int) infer_pdf.columns = ["doc", "start", "stop"] - infer_pdf["confidences"] = x.confidences.tolist() - infer_pdf["labels"] = x.labels.tolist() - infer_pdf["token_ids"] = x.input_ids.tolist() + infer_pdf["confidences"] = 
x.get_tensor('confidences').tolist() + infer_pdf["labels"] = x.get_tensor('labels').tolist() + infer_pdf["token_ids"] = x.get_tensor('input_ids').tolist() infer_pdf["confidences"] = infer_pdf.apply(lambda row: row["confidences"][row["start"]:row["stop"]], axis=1) @@ -115,8 +117,7 @@ def _postprocess(self, x: MultiPostprocLogParsingMessage): # decode cleanup parsed_df = self.__decode_cleanup(parsed_df) - - return MessageMeta(df=parsed_df) + return MessageMeta(df=cudf.DataFrame.from_pandas(parsed_df)) def __get_label_dicts(self, row): token_dict = defaultdict(str) diff --git a/examples/log_parsing/run.py b/examples/log_parsing/run.py index 71db12c831..b272a4625b 100644 --- a/examples/log_parsing/run.py +++ b/examples/log_parsing/run.py @@ -19,7 +19,6 @@ from postprocessing import LogParsingPostProcessingStage from morpheus.config import Config -from morpheus.config import CppConfig from morpheus.config import PipelineModes from morpheus.pipeline import LinearPipeline from morpheus.stages.general.monitor_stage import MonitorStage @@ -92,8 +91,6 @@ def run_pipeline( model_config_file, server_url, ): - CppConfig.set_should_use_cpp(False) - config = Config() config.mode = PipelineModes.NLP config.num_threads = num_threads diff --git a/morpheus/_lib/src/io/serializers.cpp b/morpheus/_lib/src/io/serializers.cpp index cf7f1d78cd..b95b32fd4b 100644 --- a/morpheus/_lib/src/io/serializers.cpp +++ b/morpheus/_lib/src/io/serializers.cpp @@ -158,7 +158,11 @@ void table_to_json(const TableInfoData& tbl, std::ostream& out_stream, bool incl OStreamSink sink(out_stream); auto destination = cudf::io::sink_info(&sink); - auto options_builder = cudf::io::json_writer_options_builder(destination, tbl_view).metadata(tbl_meta).lines(true); + auto options_builder = cudf::io::json_writer_options_builder(destination, tbl_view) + .metadata(tbl_meta) + .lines(true) + .include_nulls(true) + .na_rep("null"); cudf::io::write_json(options_builder.build(), rmm::mr::get_current_device_resource()); diff --git a/morpheus/cli/commands.py b/morpheus/cli/commands.py index 21c134657c..30392e0dbd 100644 --- a/morpheus/cli/commands.py +++ b/morpheus/cli/commands.py @@ -684,7 +684,7 @@ def post_pipeline(ctx: click.Context, *args, **kwargs): "morpheus.stages.inference.auto_encoder_inference_stage.AutoEncoderInferenceStage", modes=AE_ONLY) add_command("inf-pytorch", "morpheus.stages.inference.pytorch_inference_stage.PyTorchInferenceStage", modes=NOT_AE) -add_command("inf-triton", "morpheus.stages.inference.triton_inference_stage.TritonInferenceStage", modes=ALL) +add_command("inf-triton", "morpheus.stages.inference.triton_inference_stage.TritonInferenceStage", modes=NOT_AE) add_command("mlflow-drift", "morpheus.stages.postprocess.ml_flow_drift_stage.MLFlowDriftStage", modes=NOT_AE) add_command("monitor", "morpheus.stages.general.monitor_stage.MonitorStage", modes=ALL) add_command("preprocess", "morpheus.stages.preprocess.preprocess_ae_stage.PreprocessAEStage", modes=AE_ONLY) diff --git a/morpheus/cli/register_stage.py b/morpheus/cli/register_stage.py index 49e01c7248..a4c6a74239 100644 --- a/morpheus/cli/register_stage.py +++ b/morpheus/cli/register_stage.py @@ -211,6 +211,11 @@ def set_options_param_type(options_kwargs: dict, annotation, doc_type: str): else: options_kwargs["type"] = annotation + elif (issubtype(annotation, dict)): + options_kwargs["multiple"] = True + options_kwargs["type"] = click.Tuple([str, str]) + options_kwargs["callback"] = lambda ctx, param, value: dict(value) + else: options_kwargs["type"] = annotation @@ 
-265,7 +270,10 @@ def register_stage_inner(stage_class: _DecoratorType) -> _DecoratorType: nonlocal command_name - if (not hasattr(stage_class, "_morpheus_registered_stage")): + # A subclass of a stage that is already registered with the CLI will already have this attribute set, + # but the command name won't match. + if (not hasattr(stage_class, "_morpheus_registered_stage") + or stage_class._morpheus_registered_stage.name != command_name): # Determine the command name if it wasnt supplied if (command_name is None): diff --git a/morpheus/messages/multi_tensor_message.py b/morpheus/messages/multi_tensor_message.py index 952ea45a2c..1c07225a6f 100644 --- a/morpheus/messages/multi_tensor_message.py +++ b/morpheus/messages/multi_tensor_message.py @@ -113,7 +113,7 @@ def __getattr__(self, name: str) -> typing.Any: if hasattr(super(), "__getattr__"): return super().__getattr__(name) - raise AttributeError + raise AttributeError(f'No attribute named "{name}" exists') def _check_id_tensor(self): diff --git a/morpheus/stages/inference/triton_inference_stage.py b/morpheus/stages/inference/triton_inference_stage.py index 8caf42c91e..e15fcfc9b0 100644 --- a/morpheus/stages/inference/triton_inference_stage.py +++ b/morpheus/stages/inference/triton_inference_stage.py @@ -18,7 +18,6 @@ import queue import typing import warnings -from abc import abstractmethod from functools import lru_cache from functools import partial @@ -203,10 +202,11 @@ class InputWrapper: """ - def __init__(self, - client: tritonclient.InferenceServerClient, - model_name: str, - config: typing.Dict[str, TritonInOut]): + def __init__( + self, + client: tritonclient.InferenceServerClient, # pylint: disable=unused-argument + model_name: str, + config: typing.Dict[str, TritonInOut]): self._config = config.copy() self._total_bytes = 0 @@ -357,7 +357,7 @@ def __init__(self, super().__init__(client, model_name, config) # Now create the necessary shared memory bits - self.region_name = model_name + "_{}".format(ShmInputWrapper.total_count) + self.region_name = f"{model_name}_{ShmInputWrapper.total_count}" ShmInputWrapper.total_count += 1 # Allocate the total memory @@ -368,7 +368,7 @@ def __init__(self, self._config[key].ptr = cp.cuda.MemoryPointer(self._memory, self._config[key].offset) # Now get the registered IPC handle - self._ipc_handle = cp.cuda.runtime.ipcGetMemHandle(self._memory.ptr) + self._ipc_handle = cp.cuda.runtime.ipcGetMemHandle(self._memory.ptr) # pylint: disable=c-extension-no-member # Finally, regester this memory with the server. Must be base64 for some reason??? client.register_cuda_shared_memory(self.region_name, base64.b64encode(self._ipc_handle), 0, self._total_bytes) @@ -405,9 +405,9 @@ def build_input(self, name: str, data: cp.ndarray, force_convert_inputs: bool) - # This class is exclusively run in the worker thread. Separating the classes helps keeps the threads separate -class _TritonInferenceWorker(InferenceWorker): +class TritonInferenceWorker(InferenceWorker): """ - This is a base class for all Triton inference server requests. + Inference worker class for all Triton inference server requests. Parameters ---------- @@ -424,12 +424,14 @@ class _TritonInferenceWorker(InferenceWorker): Whether or not to convert the inputs to the type specified by Triton. This will happen automatically if no data would be lost in the conversion (i.e., float -> double). Set this to True to convert the input even if data would be lost (i.e., double -> float). 
- inout_mapping : typing.Dict[str, str] + inout_mapping : dict[str, str] Dictionary used to map pipeline input/output names to Triton input/output names. Use this if the Morpheus names do not match the model. use_shared_memory: bool, default = False Whether or not to use CUDA Shared IPC Memory for transferring data to Triton. Using CUDA IPC reduces network transfer time but requires that Morpheus and Triton are located on the same machine. + needs_logits : bool, default = False + Determines whether a logits calculation is needed for the value returned by the Triton inference response. """ def __init__(self, @@ -438,28 +440,22 @@ def __init__(self, model_name: str, server_url: str, force_convert_inputs: bool, - inout_mapping: typing.Dict[str, str] = None, - use_shared_memory: bool = False): + inout_mapping: dict[str, str] = None, + use_shared_memory: bool = False, + needs_logits: bool = False): super().__init__(inf_queue) - # Combine the class defaults with any user supplied ones - default_mapping = type(self).default_inout_mapping() - - default_mapping.update(inout_mapping if inout_mapping is not None else {}) - self._model_name = model_name self._server_url = server_url - self._inout_mapping = default_mapping + self._inout_mapping = inout_mapping or {} self._use_shared_memory = use_shared_memory - self._requires_seg_ids = False - self._max_batch_size = c.model_max_batch_size self._fea_length = c.feature_length self._force_convert_inputs = force_convert_inputs # Whether or not the returned value needs a logits calc for the response - self._needs_logits = type(self).needs_logits() + self._needs_logits = needs_logits self._inputs: typing.Dict[str, TritonInOut] = {} self._outputs: typing.Dict[str, TritonInOut] = {} @@ -472,18 +468,13 @@ def supports_cpp_node(cls): # Enable support by default return True - @classmethod - def needs_logits(cls): - return False - - @classmethod - def default_inout_mapping(cls) -> typing.Dict[str, str]: - return {} + @property + def needs_logits(self) -> bool: + return self._needs_logits def init(self): """ This function instantiate triton client and memory allocation for inference input and output. - """ self._triton_client = tritonclient.InferenceServerClient(url=self._server_url, verbose=False) @@ -504,29 +495,29 @@ def init(self): # Make sure the inputs/outputs match our config if (int(model_meta["inputs"][0]["shape"][-1]) != self._fea_length): - raise RuntimeError("Mismatched Sequence Length. Config specified {} but model specified {}".format( - self._fea_length, int(model_meta["inputs"][0]["shape"][-1]))) + raise RuntimeError(f"Mismatched Sequence Length. Config specified {self._fea_length} but model" + f" specified {int(model_meta['inputs'][0]['shape'][-1])}") # Check batch size if (model_config.get("max_batch_size", 0) != self._max_batch_size): # If the model is more, thats fine. Gen warning if (model_config["max_batch_size"] > self._max_batch_size): - warnings.warn(("Model max batch size ({}) is more than configured max batch size ({}). " - "May result in sub optimal performance").format(model_config["max_batch_size"], - self._max_batch_size)) + warnings.warn( + f"Model max batch size ({model_config['max_batch_size']}) is more than configured max batch " + f"size ({self._max_batch_size}). May result in sub optimal performance") # If the model is less, raise error. 
Cant send more to Triton than the max batch size if (model_config["max_batch_size"] < self._max_batch_size): raise RuntimeError( - ("Model max batch size ({}) is less than configured max batch size ({}). " - "Reduce max batch size to be less than or equal to model max batch size.").format( - model_config["max_batch_size"], self._max_batch_size)) + f"Model max batch size ({model_config['max_batch_size']}) is less than configured max batch" + f" size ({self._max_batch_size}). Reduce max batch size to be less than or equal to model max" + " batch size.") shm_config = {} def build_inout(x: dict): - b = np.dtype(triton_to_np_dtype(x["datatype"])).itemsize + num_bytes = np.dtype(triton_to_np_dtype(x["datatype"])).itemsize shape = [] @@ -538,12 +529,12 @@ def build_inout(x: dict): shape.append(y_int) - b *= y_int + num_bytes *= y_int mapped_name = x["name"] if x["name"] not in self._inout_mapping else self._inout_mapping[x["name"]] return TritonInOut(name=x["name"], - bytes=b, + bytes=num_bytes, datatype=x["datatype"], shape=shape, mapped_name=mapped_name) @@ -573,17 +564,34 @@ def create_wrapper(): self._mem_pool = ResourcePool(create_fn=create_wrapper, max_size=1000) except InferenceServerException as ex: - logger.exception("Exception occurred while coordinating with Triton. Exception message: \n{}\n".format(ex), + logger.exception("Exception occurred while coordinating with Triton. Exception message: \n%s\n", + ex, exc_info=ex) raise ex def calc_output_dims(self, x: MultiInferenceMessage) -> typing.Tuple: return (x.count, self._outputs[list(self._outputs.keys())[0]].shape[1]) - @abstractmethod - def _build_response(self, batch: MultiInferenceMessage, result: tritonclient.InferResult) -> TensorMemory: - pass + def _build_response( + self, + batch: MultiInferenceMessage, # pylint: disable=unused-argument + result: tritonclient.InferResult) -> TensorMemory: + output = {output.mapped_name: result.as_numpy(output.name) for output in self._outputs.values()} + # Make sure we have at least 2 dims + for key, val in output.items(): + if (len(val.shape) == 1): + output[key] = np.expand_dims(val, 1) + + if (self._needs_logits): + output = {key: 1.0 / (1.0 + np.exp(-val)) for key, val in output.items()} + + return TensorMemory( + count=output["probs"].shape[0], + tensors={'probs': cp.array(output["probs"])} # For now, only support one output + ) + + # pylint: disable=invalid-name def _infer_callback(self, cb: typing.Callable[[TensorMemory], None], m: InputWrapper, @@ -603,7 +611,9 @@ def _infer_callback(self, self._mem_pool.return_obj(m) - def process(self, batch: MultiInferenceMessage, cb: typing.Callable[[TensorMemory], None]): + # pylint: enable=invalid-name + + def process(self, batch: MultiInferenceMessage, callback: typing.Callable[[TensorMemory], None]): """ This function sends batch of events as a requests to Triton inference server using triton client API. @@ -611,7 +621,7 @@ def process(self, batch: MultiInferenceMessage, cb: typing.Callable[[TensorMemor ---------- batch : `morpheus.pipeline.messages.MultiInferenceMessage` Mini-batch of inference messages. - cb : typing.Callable[[`morpheus.pipeline.messages.TensorMemory`], None] + callback : typing.Callable[[`morpheus.pipeline.messages.TensorMemory`], None] Callback to set the values for the inference response. 
""" @@ -628,329 +638,106 @@ def process(self, batch: MultiInferenceMessage, cb: typing.Callable[[TensorMemor # Inference call self._triton_client.async_infer(model_name=self._model_name, inputs=inputs, - callback=partial(self._infer_callback, cb, mem, batch), + callback=partial(self._infer_callback, callback, mem, batch), outputs=outputs) -class TritonInferenceNLP(_TritonInferenceWorker): +@register_stage("inf-triton", modes=[PipelineModes.NLP, PipelineModes.FIL, PipelineModes.OTHER]) +class TritonInferenceStage(InferenceStage): """ - This class extends TritonInference to deal with scenario-specific NLP models inference requests like building - response. + Perform inference with Triton Inference Server. + + This class specifies which inference implementation category (Ex: NLP/FIL) is needed for inferencing. Parameters ---------- - inf_queue : `morpheus.utils.producer_consumer_queue.ProducerConsumerQueue` - Inference queue. c : `morpheus.config.Config` Pipeline configuration instance. model_name : str - Name of the model specifies which model can handle the inference requests that are sent to Triton - inference server. + Name of the model specifies which model can handle the inference requests that are sent to Triton inference + server. server_url : str - Triton server gRPC URL including the port. + Triton server URL. force_convert_inputs : bool, default = False - Whether or not to convert the inputs to the type specified by Triton. This will happen automatically if no - data would be lost in the conversion (i.e., float -> double). Set this to True to convert the input even if - data would be lost (i.e., double -> float). - use_shared_memory : bool, default = False + Instructs the stage to convert the incoming data to the same format that Triton is expecting. If set to False, + data will only be converted if it would not result in the loss of data. + use_shared_memory : bool, default = False, is_flag = True Whether or not to use CUDA Shared IPC Memory for transferring data to Triton. Using CUDA IPC reduces network transfer time but requires that Morpheus and Triton are located on the same machine. - inout_mapping : typing.Dict[str, str] - Dictionary used to map pipeline input/output names to Triton input/output names. Use this if the - Morpheus names do not match the model. - - """ - - def __init__(self, - inf_queue: ProducerConsumerQueue, - c: Config, - model_name: str, - server_url: str, - force_convert_inputs: bool = False, - use_shared_memory: bool = False, - inout_mapping: typing.Dict[str, str] = None): - super().__init__(inf_queue, - c, - model_name=model_name, - server_url=server_url, - force_convert_inputs=force_convert_inputs, - use_shared_memory=use_shared_memory, - inout_mapping=inout_mapping) + needs_logits : bool, optional + Determines whether a logits calculation is needed for the value returned by the Triton inference response. If + undefined, the value will be inferred based on the pipeline mode, defaulting to `True` for NLP and `False` for + other modes. + inout_mapping : dict[str, str], optional + Dictionary used to map pipeline input/output names to Triton input/output names. + Use this if the Morpheus names do not match the model. + If undefined, a default mapping will be used based on the pipeline mode as follows: - @classmethod - def needs_logits(cls): - """ - Determines whether a logits calculation is needed for the value returned by the Triton inference response. 
- """ - return True + * `FIL`: `{"output__0": "probs"}` - @classmethod - def default_inout_mapping(cls) -> typing.Dict[str, str]: - """ - Returns default dictionary used to map NLP pipeline input/output names to Triton input/output names + * `NLP`: `{"attention_mask": "input_mask", "output": "probs"}` - Returns - ------- - default_inout_mapping : typing.Dict[str, str] - Dictionary with default input and output names. - """ + * All other modes: `{}` - # Some models use different names for the same thing. Set that here but allow user customization - return { - "attention_mask": "input_mask", - "output": "probs", - } - - def _build_response(self, batch: MultiInferenceMessage, result: tritonclient.InferResult) -> TensorMemory: - - output = {output.mapped_name: result.as_numpy(output.name) for output in self._outputs.values()} - - if (self._needs_logits): - output = {key: 1.0 / (1.0 + np.exp(-val)) for key, val in output.items()} - - mem = TensorMemory( - count=output["probs"].shape[0], - tensors={'probs': cp.array(output["probs"])} # For now, only support one output - ) + From the command line this can be specified multiple times for each key/value pair, for example: - return mem + --inout-mapping mask input_mask --inout-mapping output probs + which will be inroduced as: -class TritonInferenceFIL(_TritonInferenceWorker): + inout_mapping={"mask": "input_mask", "output": "probs"} """ - This class extends `TritonInference` to deal with scenario-specific FIL models inference requests like - building response. - Parameters - ---------- - inf_queue : `morpheus.utils.producer_consumer_queue.ProducerConsumerQueue` - Inference queue. - c : `morpheus.config.Config` - Pipeline configuration instance. - model_name : str - Name of the model specifies which model can handle the inference requests that are sent to Triton - inference server. - server_url : str - Triton server gRPC URL including the port. - force_convert_inputs : bool, default = False - Whether or not to convert the inputs to the type specified by Triton. This will happen automatically if no - data would be lost in the conversion (i.e., float -> double). Set this to True to convert the input even if - data would be lost (i.e., double -> float). - use_shared_memory: bool, default = False - Whether or not to use CUDA Shared IPC Memory for transferring data to Triton. Using CUDA IPC reduces network - transfer time but requires that Morpheus and Triton are located on the same machine. - inout_mapping : typing.Dict[str, str] - Dictionary used to map pipeline input/output names to Triton input/output names. Use this if the - Morpheus names do not match the model. - - """ - - def __init__(self, - inf_queue: ProducerConsumerQueue, - c: Config, - model_name: str, - server_url: str, - force_convert_inputs: bool = False, - use_shared_memory: bool = False, - inout_mapping: typing.Dict[str, str] = None): - super().__init__(inf_queue, - c, - model_name=model_name, - server_url=server_url, - force_convert_inputs=force_convert_inputs, - use_shared_memory=use_shared_memory, - inout_mapping=inout_mapping) - - @classmethod - def default_inout_mapping(cls) -> typing.Dict[str, str]: - """ - Returns default dictionary used to map FIL pipeline input/output names to Triton input/output names - - Returns - ------- - default_inout_mapping : typing.Dict[str, str] - Dictionary with default input and output names. - """ - # Some models use different names for the same thing. 
Set that here but allow user customization - return { + _INFERENCE_WORKER_DEFAULT_INOUT_MAPPING = { + PipelineModes.FIL: { "output__0": "probs", + }, + PipelineModes.NLP: { + "attention_mask": "input_mask", + "output": "probs", } - - def _build_response(self, batch: MultiInferenceMessage, result: tritonclient.InferResult) -> TensorMemory: - - output = {output.mapped_name: result.as_numpy(output.name) for output in self._outputs.values()} - - for key, val in output.items(): - if (len(val.shape) == 1): - output[key] = np.expand_dims(val, 1) - - mem = TensorMemory( - count=output["probs"].shape[0], - tensors={'probs': cp.array(output["probs"])} # For now, only support one output - ) - - return mem - - -class TritonInferenceAE(_TritonInferenceWorker): - """ - This class extends `TritonInference` to deal with inference processing specific to the AutoEncoder. - - Parameters - ---------- - inf_queue : `morpheus.utils.producer_consumer_queue.ProducerConsumerQueue` - Inference queue. - c : `morpheus.config.Config` - Pipeline configuration instance. - model_name : str - Name of the model specifies which model can handle the inference requests that are sent to Triton - inference server. - server_url : str - Triton server gRPC URL including the port. - force_convert_inputs : bool, default = False - Whether or not to convert the inputs to the type specified by Triton. This will happen automatically if no - data would be lost in the conversion (i.e., float -> double). Set this to True to convert the input even if - data would be lost (i.e., double -> float). - use_shared_memory: bool, default = False - Whether or not to use CUDA Shared IPC Memory for transferring data to Triton. Using CUDA IPC reduces network - transfer time but requires that Morpheus and Triton are located on the same machine. - inout_mapping : typing.Dict[str, str] - Dictionary used to map pipeline input/output names to Triton input/output names. Use this if the - Morpheus names do not match the model. - - """ + } def __init__(self, - inf_queue: ProducerConsumerQueue, c: Config, model_name: str, server_url: str, force_convert_inputs: bool = False, use_shared_memory: bool = False, - inout_mapping: typing.Dict[str, str] = None): - super().__init__(inf_queue, - c, - model_name=model_name, - server_url=server_url, - force_convert_inputs=force_convert_inputs, - use_shared_memory=use_shared_memory, - inout_mapping=inout_mapping) - - import torch - - from morpheus.models.dfencoder import AutoEncoder - - # Save the autoencoder path - with open(c.ae.autoencoder_path, 'rb') as in_strm: - self._autoencoder = AutoEncoder() - self._autoencoder.load_state_dict(torch.load(in_strm)) - - # Ensure that there is a label_smoothing property on cce. 
Necessary if pytorch version is different - if (not hasattr(self._autoencoder.cce, "label_smoothing")): - self._autoencoder.cce.label_smoothing = 0.0 - - @classmethod - def supports_cpp_node(cls): - # Enable support by default - return False - - def _build_response(self, batch: MultiInferenceMessage, result: tritonclient.InferResult) -> TensorMemory: - - import torch - - output = {output.mapped_name: result.as_numpy(output.name) for output in self._outputs.values()} - - data = self._autoencoder.prepare_df(batch.get_meta()) - num_target, bin_target, codes = self._autoencoder.compute_targets(data) - mse_loss = self._autoencoder.mse(torch.as_tensor(output["num"], device='cuda'), num_target) - net_loss = [mse_loss.data] - bce_loss = self._autoencoder.bce(torch.as_tensor(output["bin"], device='cuda'), bin_target) - net_loss += [bce_loss.data] - cce_loss = [] - for i, ft in enumerate(self._autoencoder.categorical_fts): - loss = self._autoencoder.cce(torch.as_tensor(output[ft], device='cuda'), codes[i]) - cce_loss.append(loss) - net_loss += [loss.data.reshape(-1, 1)] - - net_loss = torch.cat(net_loss, dim=1).mean(dim=1) - ae_scores = cp.asarray(net_loss) - ae_scores = ae_scores.reshape((batch.count, 1)) - - mem = TensorMemory( - count=batch.count, - tensors={'probs': ae_scores} # For now, only support one output - ) - - return mem - - -@register_stage("inf-triton") -class TritonInferenceStage(InferenceStage): - """ - Perform inference with Triton Inference Server. - - This class specifies which inference implementation category (Ex: NLP/FIL) is needed for inferencing. - - Parameters - ---------- - c : `morpheus.config.Config` - Pipeline configuration instance. - model_name : str - Name of the model specifies which model can handle the inference requests that are sent to Triton inference - server. - server_url : str - Triton server URL. - force_convert_inputs : bool, default = False - Instructs the stage to convert the incoming data to the same format that Triton is expecting. If set to False, - data will only be converted if it would not result in the loss of data. - use_shared_memory : bool, default = False, is_flag = True - Whether or not to use CUDA Shared IPC Memory for transferring data to Triton. Using CUDA IPC reduces network - transfer time but requires that Morpheus and Triton are located on the same machine. 
- """ - - def __init__(self, - c: Config, - model_name: str, - server_url: str, - force_convert_inputs: bool = False, - use_shared_memory: bool = False): + needs_logits: bool = None, + inout_mapping: dict[str, str] = None): super().__init__(c) self._config = c + if needs_logits is None: + needs_logits = c.mode == PipelineModes.NLP + + # Combine the pipeline mode defaults with any user supplied ones + inout_mapping_ = self._INFERENCE_WORKER_DEFAULT_INOUT_MAPPING.get(c.mode, {}) + if inout_mapping is not None: + inout_mapping_.update(inout_mapping) + self._kwargs = { "model_name": model_name, "server_url": server_url, "force_convert_inputs": force_convert_inputs, "use_shared_memory": use_shared_memory, + "inout_mapping": inout_mapping_, + "needs_logits": needs_logits } - self._requires_seg_ids = False - - def supports_cpp_node(self): + def supports_cpp_node(self) -> bool: # Get the value from the worker class - return self._get_worker_class().supports_cpp_node() - - def _get_worker_class(self): - if (self._config.mode == PipelineModes.NLP): - return TritonInferenceNLP - elif (self._config.mode == PipelineModes.FIL): - return TritonInferenceFIL - elif (self._config.mode == PipelineModes.AE): - return TritonInferenceAE - else: - raise NotImplementedError("Unknown config mode") - - def _get_inference_worker(self, inf_queue: ProducerConsumerQueue) -> InferenceWorker: + return TritonInferenceWorker.supports_cpp_node() - worker_cls = self._get_worker_class() - - return worker_cls(inf_queue=inf_queue, c=self._config, **self._kwargs) + def _get_inference_worker(self, inf_queue: ProducerConsumerQueue) -> TritonInferenceWorker: + """ + Returns the worker for this stage. Authors of custom sub-classes can override this method to provide a custom + worker. + """ - def _get_cpp_inference_node(self, builder: mrc.Builder): + return TritonInferenceWorker(inf_queue=inf_queue, c=self._config, **self._kwargs) - return _stages.InferenceClientStage(builder, - name=self.unique_name, - needs_logits=self._get_worker_class().needs_logits(), - inout_mapping=self._get_worker_class().default_inout_mapping(), - **self._kwargs) + def _get_cpp_inference_node(self, builder: mrc.Builder) -> mrc.SegmentObject: + return _stages.InferenceClientStage(builder, name=self.unique_name, **self._kwargs) diff --git a/tests/examples/log_parsing/conftest.py b/tests/examples/log_parsing/conftest.py index a9d9c21f5e..7af1178939 100644 --- a/tests/examples/log_parsing/conftest.py +++ b/tests/examples/log_parsing/conftest.py @@ -16,8 +16,8 @@ import pytest -@pytest.fixture -def config(config): # pylint: disable=redefined-outer-name +@pytest.fixture(name="config") +def config_fixture(config): """ The log_parsing pipelie requires NLP mode. Set this here so all the tests don't need to set it themselves. 
""" diff --git a/tests/examples/log_parsing/test_inference.py b/tests/examples/log_parsing/test_inference.py index 3325c45d71..928689aa5d 100644 --- a/tests/examples/log_parsing/test_inference.py +++ b/tests/examples/log_parsing/test_inference.py @@ -20,21 +20,22 @@ import cupy as cp import numpy as np -import pandas as pd import pytest -import cudf - from _utils import TEST_DIRS from morpheus.config import Config +from morpheus.config import PipelineModes from morpheus.messages import InferenceMemoryNLP from morpheus.messages import MessageMeta -from morpheus.messages import MultiInferenceMessage -from morpheus.stages.inference.triton_inference_stage import _TritonInferenceWorker +from morpheus.messages import MultiInferenceNLPMessage +from morpheus.messages import MultiResponseMessage +from morpheus.messages import TensorMemory +from morpheus.stages.inference.triton_inference_stage import TritonInferenceWorker from morpheus.utils.producer_consumer_queue import ProducerConsumerQueue +from morpheus.utils.type_aliases import DataFrameType -def build_response_mem(messages_mod, log_test_data_dir: str): +def build_response_mem(log_test_data_dir: str) -> TensorMemory: # we have tensor data for the first five rows count = 5 tensors = {} @@ -43,15 +44,33 @@ def build_response_mem(messages_mod, log_test_data_dir: str): host_data = np.loadtxt(tensor_file, delimiter=',') tensors[tensor_name] = cp.asarray(host_data) - return messages_mod.ResponseMemoryLogParsing(count=count, **tensors) + return TensorMemory(count=count, tensors=tensors) + + +def build_resp_message(df: DataFrameType, num_cols: int = 2) -> MultiResponseMessage: + count = len(df) + seq_ids = cp.zeros((count, 3), dtype=cp.uint32) + seq_ids[:, 0] = cp.arange(0, count, dtype=cp.uint32) + seq_ids[:, 2] = 42 + + meta = MessageMeta(df) + mem = TensorMemory(count=count, + tensors={ + 'confidences': cp.zeros((count, num_cols)), + 'labels': cp.zeros((count, num_cols)), + 'input_ids': cp.zeros((count, num_cols), dtype=cp.float32), + 'seq_ids': seq_ids + }) + + return MultiResponseMessage(meta=meta, mess_offset=0, mess_count=count, memory=mem, offset=0, count=count) -def build_inf_message(df: typing.Union[pd.DataFrame, cudf.DataFrame], +def build_inf_message(df: DataFrameType, mess_offset: int, mess_count: int, offset: int, count: int, - num_cols: int = 2) -> MultiInferenceMessage: + num_cols: int = 2) -> MultiInferenceNLPMessage: assert count >= mess_count tensor_length = offset + count seq_ids = cp.zeros((tensor_length, 3), dtype=cp.uint32) @@ -69,28 +88,25 @@ def build_inf_message(df: typing.Union[pd.DataFrame, cudf.DataFrame], input_mask=cp.zeros((tensor_length, num_cols), dtype=cp.float32), seq_ids=seq_ids) - return MultiInferenceMessage(meta=meta, - mess_offset=mess_offset, - mess_count=mess_count, - memory=mem, - offset=offset, - count=count) + return MultiInferenceNLPMessage(meta=meta, + mess_offset=mess_offset, + mess_count=mess_count, + memory=mem, + offset=offset, + count=count) -def _check_worker(inference_mod: types.ModuleType, worker: _TritonInferenceWorker): - assert isinstance(worker, _TritonInferenceWorker) +def _check_worker(inference_mod: types.ModuleType, worker: TritonInferenceWorker, expected_mapping: dict[str, str]): + assert isinstance(worker, TritonInferenceWorker) assert isinstance(worker, inference_mod.TritonInferenceLogParsing) assert worker._model_name == 'test_model' assert worker._server_url == 'test_server' assert not worker._force_convert_inputs assert not worker._use_shared_memory - - expected_mapping = 
inference_mod.TritonInferenceLogParsing.default_inout_mapping() - expected_mapping.update({'test': 'this'}) + assert worker.needs_logits assert worker._inout_mapping == expected_mapping -@pytest.mark.use_python @pytest.mark.import_mod([os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'inference.py')]) def test_log_parsing_triton_inference_log_parsing_constructor(config: Config, import_mod: typing.List[types.ModuleType]): @@ -101,17 +117,16 @@ def test_log_parsing_triton_inference_log_parsing_constructor(config: Config, server_url='test_server', force_convert_inputs=False, use_shared_memory=False, - inout_mapping={'test': 'this'}) + inout_mapping={'test': 'this'}, + needs_logits=True) - _check_worker(inference_mod, worker) + _check_worker(inference_mod, worker, {'test': 'this'}) -@pytest.mark.use_python @pytest.mark.import_mod([os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'inference.py')]) @pytest.mark.parametrize("mess_offset,mess_count,offset,count", [(0, 20, 0, 20), (5, 10, 5, 10)]) def test_log_parsing_triton_inference_log_parsing_build_output_message(config: Config, - filter_probs_df: typing.Union[pd.DataFrame, - cudf.DataFrame], + filter_probs_df: DataFrameType, import_mod: typing.List[types.ModuleType], mess_offset: int, mess_count: int, @@ -143,16 +158,31 @@ def test_log_parsing_triton_inference_log_parsing_build_output_message(config: C assert msg.count == count assert set(msg.memory.tensor_names).issuperset(('confidences', 'labels', 'input_ids', 'seq_ids')) - assert msg.confidences.shape == (count, 2) - assert msg.labels.shape == (count, 2) - assert msg.input_ids.shape == (count, 2) - assert msg.seq_ids.shape == (count, 3) + assert msg.get_tensor('confidences').shape == (count, 2) + assert msg.get_tensor('labels').shape == (count, 2) + assert msg.get_tensor('input_ids').shape == (count, 2) + assert msg.get_tensor('seq_ids').shape == (count, 3) -@pytest.mark.use_python @pytest.mark.import_mod([os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'inference.py')]) def test_log_parsing_inference_stage_constructor(config: Config, import_mod: typing.List[types.ModuleType]): inference_mod = import_mod[0] + + expected_kwargs = { + "model_name": + 'test_model', + "server_url": + 'test_server', + "force_convert_inputs": + False, + "use_shared_memory": + False, + "needs_logits": + True, + "inout_mapping": + inference_mod.LogParsingInferenceStage._INFERENCE_WORKER_DEFAULT_INOUT_MAPPING.get(PipelineModes.NLP, {}), + } + stage = inference_mod.LogParsingInferenceStage( config, model_name='test_model', @@ -162,63 +192,47 @@ def test_log_parsing_inference_stage_constructor(config: Config, import_mod: typ ) assert stage._config is config - assert stage._kwargs == { - "model_name": 'test_model', - "server_url": 'test_server', - "force_convert_inputs": False, - "use_shared_memory": False - } + assert stage._kwargs == expected_kwargs - # Intentionally not checking the `_requires_seg_ids` value at it appears to not be used - -@pytest.mark.use_python @pytest.mark.import_mod([os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'inference.py')]) def test_log_parsing_inference_stage_get_inference_worker(config: Config, import_mod: typing.List[types.ModuleType]): inference_mod = import_mod[0] - stage = inference_mod.LogParsingInferenceStage( - config, - model_name='test_model', - server_url='test_server', - force_convert_inputs=False, - use_shared_memory=False, - ) + stage = inference_mod.LogParsingInferenceStage(config, + model_name='test_model', + server_url='test_server', + 
force_convert_inputs=False, + use_shared_memory=False, + inout_mapping={'test': 'this'}) - stage._kwargs.update({'inout_mapping': {'test': 'this'}}) + expected_mapping = inference_mod.LogParsingInferenceStage._INFERENCE_WORKER_DEFAULT_INOUT_MAPPING.get( + PipelineModes.NLP, {}) + expected_mapping.update({'test': 'this'}) worker = stage._get_inference_worker(inf_queue=ProducerConsumerQueue()) - _check_worker(inference_mod, worker) + _check_worker(inference_mod, worker, expected_mapping) -@pytest.mark.use_python +@pytest.mark.use_cudf @pytest.mark.usefixtures("manual_seed", "config") -@pytest.mark.import_mod([ - os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'inference.py'), - os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'messages.py') -]) +@pytest.mark.import_mod(os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'inference.py')) @pytest.mark.parametrize("mess_offset,mess_count,offset,count", [(0, 5, 0, 5), (5, 5, 0, 5)]) def test_log_parsing_inference_stage_convert_one_response(import_mod: typing.List[types.ModuleType], - filter_probs_df: typing.Union[pd.DataFrame, cudf.DataFrame], + filter_probs_df: DataFrameType, mess_offset, mess_count, offset, count): - inference_mod, messages_mod = import_mod - - ttl_count = len(filter_probs_df) + inference_mod = import_mod - input_res = build_response_mem(messages_mod, os.path.join(TEST_DIRS.tests_data_dir, 'examples/log_parsing')) + input_res = build_response_mem(os.path.join(TEST_DIRS.tests_data_dir, 'examples/log_parsing')) # confidences, labels & input_ids all have the same shape - num_cols = input_res.confidences.shape[1] - input_mem = messages_mod.PostprocMemoryLogParsing( - count=ttl_count, - confidences=cp.zeros((ttl_count, num_cols), dtype=cp.float32), - input_ids=cp.zeros((ttl_count, num_cols), dtype=cp.float32), - labels=cp.zeros((ttl_count, num_cols), dtype=cp.float32), - seq_ids=cp.zeros((ttl_count, 3), dtype=cp.uint32), - ) + num_cols = input_res.get_tensor('confidences').shape[1] + resp_msg = build_resp_message(filter_probs_df, num_cols=num_cols) + + orig_tensors = {k: v.copy() for (k, v) in resp_msg.memory.get_tensors().items()} input_inf = build_inf_message(filter_probs_df, mess_offset=mess_offset, @@ -227,23 +241,25 @@ def test_log_parsing_inference_stage_convert_one_response(import_mod: typing.Lis count=count, num_cols=num_cols) - output_msg = inference_mod.LogParsingInferenceStage._convert_one_response(input_mem, input_inf, input_res) + output_msg = inference_mod.LogParsingInferenceStage._convert_one_response(resp_msg, input_inf, input_res) - assert isinstance(output_msg, messages_mod.MultiPostprocLogParsingMessage) + assert isinstance(output_msg, MultiResponseMessage) assert output_msg.meta is input_inf.meta - assert output_msg.memory is input_mem assert output_msg.mess_offset == mess_offset assert output_msg.mess_count == mess_count assert output_msg.offset == offset assert output_msg.count == count - assert (output_msg.seq_ids == input_inf.seq_ids).all() - assert (output_msg.input_ids == input_inf.input_ids).all() - assert (output_msg.confidences == input_res.confidences).all() - assert (output_msg.labels == input_res.labels).all() + assert (output_msg.get_tensor('seq_ids') == input_inf.get_tensor('seq_ids')).all() + assert (output_msg.get_tensor('input_ids') == input_inf.get_tensor('input_ids')).all() + assert (output_msg.get_tensor('confidences') == input_res.get_tensor('confidences')).all() + assert (output_msg.get_tensor('labels') == input_res.get_tensor('labels')).all() # Ensure we didn't write to the memory 
outside of the [offset:offset+count] bounds - tensors = input_mem.get_tensors() + tensors = resp_msg.memory.get_tensors() for (tensor_name, tensor) in tensors.items(): - assert (tensor[0:offset] == 0).all(), f"Out of bounds values for {tensor_name}" - assert (tensor[offset + count:] == 0).all(), f"Out of bounds values for {tensor_name}" + orig_tensor = orig_tensors[tensor_name] + + error_msg = f"Out of bounds values for {tensor_name}" + assert (tensor[0:offset] == orig_tensor[0:offset]).all(), error_msg + assert (tensor[offset + count:] == orig_tensor[offset + count:]).all(), error_msg diff --git a/tests/examples/log_parsing/test_log_parsing_pipe.py b/tests/examples/log_parsing/test_log_parsing_pipe.py index d091cada6f..7d91d5496d 100755 --- a/tests/examples/log_parsing/test_log_parsing_pipe.py +++ b/tests/examples/log_parsing/test_log_parsing_pipe.py @@ -126,7 +126,6 @@ def _run_mocked_pipeline(config: Config, dataset_cudf: DatasetManager, import_mo @pytest.mark.slow -@pytest.mark.use_python @pytest.mark.import_mod([ os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'inference.py'), os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'postprocessing.py') diff --git a/tests/examples/log_parsing/test_postprocessing.py b/tests/examples/log_parsing/test_postprocessing.py index 8627eb1e1c..49d108d2c3 100644 --- a/tests/examples/log_parsing/test_postprocessing.py +++ b/tests/examples/log_parsing/test_postprocessing.py @@ -25,9 +25,11 @@ from _utils.dataset_manager import DatasetManager from morpheus.config import Config from morpheus.messages import MessageMeta +from morpheus.messages import MultiResponseMessage +from morpheus.messages import TensorMemory -def build_post_proc_message(messages_mod, dataset_cudf: DatasetManager, log_test_data_dir: str): +def build_post_proc_message(dataset_cudf: DatasetManager, log_test_data_dir: str): input_file = os.path.join(TEST_DIRS.validation_data_dir, 'log-parsing-validation-data-input.csv') input_df = dataset_cudf[input_file] meta = MessageMeta(input_df) @@ -35,29 +37,26 @@ def build_post_proc_message(messages_mod, dataset_cudf: DatasetManager, log_test # we have tensor data for the first five rows count = 5 tensors = {} - for tensor_name in ['confidences', 'input_ids', 'labels', 'seq_ids']: + for tensor_name in ['confidences', 'input_ids', 'labels']: tensor_file = os.path.join(log_test_data_dir, f'{tensor_name}.csv') host_data = np.loadtxt(tensor_file, delimiter=',') tensors[tensor_name] = cp.asarray(host_data) - memory = messages_mod.PostprocMemoryLogParsing(count=5, **tensors) - return messages_mod.MultiPostprocLogParsingMessage(meta=meta, - mess_offset=0, - mess_count=count, - memory=memory, - offset=0, - count=count) + host__seq_data = np.loadtxt(os.path.join(log_test_data_dir, 'seq_ids.csv'), delimiter=',') + seq_ids = cp.zeros((count, 3), dtype=cp.uint32) + seq_ids[:, 0] = cp.arange(0, 5, dtype=cp.uint32) + seq_ids[:, 2] = cp.asarray(host__seq_data)[:, 2] + tensors['seq_ids'] = seq_ids + memory = TensorMemory(count=5, tensors=tensors) + return MultiResponseMessage(meta=meta, mess_offset=0, mess_count=count, memory=memory, offset=0, count=count) -@pytest.mark.use_python -@pytest.mark.import_mod([ - os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'messages.py'), - os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'postprocessing.py') -]) + +@pytest.mark.import_mod(os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'postprocessing.py')) def test_log_parsing_post_processing_stage(config: Config, dataset_cudf: DatasetManager, import_mod: 
typing.List[types.ModuleType]): - messages_mod, postprocessing_mod = import_mod + postprocessing_mod = import_mod model_vocab_file = os.path.join(TEST_DIRS.data_dir, 'bert-base-cased-vocab.txt') log_test_data_dir = os.path.join(TEST_DIRS.tests_data_dir, 'examples/log_parsing') @@ -67,10 +66,10 @@ def test_log_parsing_post_processing_stage(config: Config, vocab_path=model_vocab_file, model_config_path=model_config_file) - post_proc_message = build_post_proc_message(messages_mod, dataset_cudf, log_test_data_dir) + post_proc_message = build_post_proc_message(dataset_cudf, log_test_data_dir) expected_df = dataset_cudf.pandas[os.path.join(log_test_data_dir, 'expected_out.csv')] out_meta = stage._postprocess(post_proc_message) assert isinstance(out_meta, MessageMeta) - DatasetManager.assert_compare_df(out_meta._df, expected_df) + DatasetManager.assert_compare_df(out_meta.df, expected_df) diff --git a/tests/test_cli.py b/tests/test_cli.py index aef467bf84..50cbe44c7e 100755 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -288,8 +288,8 @@ def test_pipeline_ae_all(self, callback_values): 'preprocess', 'inf-pytorch', 'add-scores' - ] + INF_TRITON_ARGS + ['timeseries', '--resolution=1m', '--zscore_threshold=8.0', '--hot_start'] + - MONITOR_ARGS + VALIDATE_ARGS + ['serialize'] + TO_FILE_ARGS + TO_KAFKA_ARGS) + ] + ['timeseries', '--resolution=1m', '--zscore_threshold=8.0', '--hot_start'] + MONITOR_ARGS + VALIDATE_ARGS + + ['serialize'] + TO_FILE_ARGS + TO_KAFKA_ARGS) runner = CliRunner() result = runner.invoke(commands.cli, args) @@ -307,7 +307,6 @@ def test_pipeline_ae_all(self, callback_values): process_ae, auto_enc, add_scores, - triton_inf, time_series, monitor, validation, @@ -331,11 +330,6 @@ def test_pipeline_ae_all(self, callback_values): assert isinstance(auto_enc, AutoEncoderInferenceStage) assert isinstance(add_scores, AddScoresStage) - assert isinstance(triton_inf, TritonInferenceStage) - assert triton_inf._kwargs['model_name'] == 'test-model' - assert triton_inf._kwargs['server_url'] == 'test:123' - assert triton_inf._kwargs['force_convert_inputs'] - assert isinstance(time_series, TimeSeriesStage) assert time_series._resolution == '1m' assert time_series._zscore_threshold == 8.0 diff --git a/tests/test_triton_inference_stage.py b/tests/test_triton_inference_stage.py index f39fb65b7c..f252cd8d79 100644 --- a/tests/test_triton_inference_stage.py +++ b/tests/test_triton_inference_stage.py @@ -25,11 +25,14 @@ from _utils import assert_results from _utils import mk_async_infer +from morpheus.config import Config from morpheus.config import ConfigFIL from morpheus.config import PipelineModes from morpheus.pipeline import LinearPipeline +from morpheus.stages.inference.triton_inference_stage import ProducerConsumerQueue from morpheus.stages.inference.triton_inference_stage import ResourcePool from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage +from morpheus.stages.inference.triton_inference_stage import TritonInferenceWorker from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage from morpheus.stages.output.compare_dataframe_stage import CompareDataFrameStage from morpheus.stages.postprocess.add_scores_stage import AddScoresStage @@ -119,6 +122,72 @@ def test_resource_pool_create_raises_error(): assert pool.borrow_obj() == 20 +@pytest.mark.parametrize("pipeline_mode", list(PipelineModes)) +@pytest.mark.parametrize("force_convert_inputs", [True, False]) +@pytest.mark.parametrize("use_shared_memory", [True, False]) 
+@pytest.mark.parametrize("needs_logits", [True, False, None]) +@pytest.mark.parametrize("inout_mapping", [None, {'unit': 'test'}]) +def test_stage_constructor(config: Config, + pipeline_mode: PipelineModes, + force_convert_inputs: bool, + use_shared_memory: bool, + needs_logits: bool | None, + inout_mapping: dict[str, str] | None): + if needs_logits is None: + expexted_needs_logits = (pipeline_mode == PipelineModes.NLP) + else: + expexted_needs_logits = needs_logits + + expected_inout_mapping = TritonInferenceStage._INFERENCE_WORKER_DEFAULT_INOUT_MAPPING.get(pipeline_mode, {}) + expected_inout_mapping.update(inout_mapping or {}) + + config.mode = pipeline_mode + + stage = TritonInferenceStage(config, + model_name='test', + server_url='test:0000', + force_convert_inputs=force_convert_inputs, + use_shared_memory=use_shared_memory, + needs_logits=needs_logits, + inout_mapping=inout_mapping) + + assert stage._kwargs == { + "model_name": "test", + "server_url": "test:0000", + "force_convert_inputs": force_convert_inputs, + "use_shared_memory": use_shared_memory, + "needs_logits": expexted_needs_logits, + 'inout_mapping': expected_inout_mapping + } + + +@pytest.mark.use_python +@pytest.mark.parametrize("pipeline_mode", list(PipelineModes)) +def test_stage_constructor_worker_class(config: Config, pipeline_mode: PipelineModes): + config.mode = pipeline_mode + stage = TritonInferenceStage(config, model_name='test', server_url='test:0000') + worker = stage._get_inference_worker(ProducerConsumerQueue()) + assert isinstance(worker, TritonInferenceWorker) + + +@pytest.mark.use_python +@pytest.mark.parametrize("pipeline_mode", list(PipelineModes)) +@pytest.mark.parametrize("needs_logits", [True, False, None]) +def test_stage_get_inference_worker(config: Config, pipeline_mode: PipelineModes, needs_logits: bool | None): + if needs_logits is None: + expexted_needs_logits = (pipeline_mode == PipelineModes.NLP) + else: + expexted_needs_logits = needs_logits + + config.mode = pipeline_mode + + stage = TritonInferenceStage(config, model_name='test', server_url='test:0000', needs_logits=needs_logits) + + worker = stage._get_inference_worker(ProducerConsumerQueue()) + assert isinstance(worker, TritonInferenceWorker) + assert worker.needs_logits == expexted_needs_logits + + @pytest.mark.slow @pytest.mark.use_python @pytest.mark.parametrize('num_records', [1000, 2000, 4000])