Skip to content

Commit

Permalink
Fix tests to opperate in C++ mode
Browse files Browse the repository at this point in the history
  • Loading branch information
dagardner-nv committed Dec 26, 2023
1 parent bbc0ee3 commit 01f0fe8
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions tests/examples/log_parsing/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from morpheus.config import PipelineModes
from morpheus.messages import InferenceMemoryNLP
from morpheus.messages import MessageMeta
from morpheus.messages import MultiResponseMessage
from morpheus.messages import MultiInferenceNLPMessage
from morpheus.messages import MultiResponseMessage
from morpheus.messages import TensorMemory
from morpheus.stages.inference.triton_inference_stage import TritonInferenceWorker
from morpheus.utils.producer_consumer_queue import ProducerConsumerQueue
Expand Down Expand Up @@ -158,10 +158,10 @@ def test_log_parsing_triton_inference_log_parsing_build_output_message(config: C
assert msg.count == count

assert set(msg.memory.tensor_names).issuperset(('confidences', 'labels', 'input_ids', 'seq_ids'))
assert msg.confidences.shape == (count, 2)
assert msg.labels.shape == (count, 2)
assert msg.input_ids.shape == (count, 2)
assert msg.seq_ids.shape == (count, 3)
assert msg.get_tensor('confidences').shape == (count, 2)
assert msg.get_tensor('labels').shape == (count, 2)
assert msg.get_tensor('input_ids').shape == (count, 2)
assert msg.get_tensor('seq_ids').shape == (count, 3)


@pytest.mark.import_mod([os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'inference.py')])
Expand Down Expand Up @@ -214,6 +214,7 @@ def test_log_parsing_inference_stage_get_inference_worker(config: Config, import
_check_worker(inference_mod, worker, expected_mapping)


@pytest.mark.use_cudf
@pytest.mark.usefixtures("manual_seed", "config")
@pytest.mark.import_mod(os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'inference.py'))
@pytest.mark.parametrize("mess_offset,mess_count,offset,count", [(0, 5, 0, 5), (5, 5, 0, 5)])
Expand Down Expand Up @@ -249,10 +250,10 @@ def test_log_parsing_inference_stage_convert_one_response(import_mod: typing.Lis
assert output_msg.offset == offset
assert output_msg.count == count

assert (output_msg.seq_ids == input_inf.seq_ids).all()
assert (output_msg.input_ids == input_inf.input_ids).all()
assert (output_msg.confidences == input_res.get_tensor('confidences')).all()
assert (output_msg.labels == input_res.get_tensor('labels')).all()
assert (output_msg.get_tensor('seq_ids') == input_inf.get_tensor('seq_ids')).all()
assert (output_msg.get_tensor('input_ids') == input_inf.get_tensor('input_ids')).all()
assert (output_msg.get_tensor('confidences') == input_res.get_tensor('confidences')).all()
assert (output_msg.get_tensor('labels') == input_res.get_tensor('labels')).all()

# Ensure we didn't write to the memory outside of the [offset:offset+count] bounds
tensors = resp_msg.memory.get_tensors()
Expand Down

0 comments on commit 01f0fe8

Please sign in to comment.