From 55fe93441a2a64dcdb4ecf10de747233d424d4a7 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Thu, 11 May 2023 09:56:43 +0100 Subject: [PATCH] Add support for int8 values to Categorify inference (#1818) --- cpp/nvtabular/inference/categorify.cc | 10 ++++++++++ tests/unit/ops/test_categorify.py | 2 ++ 2 files changed, 12 insertions(+) diff --git a/cpp/nvtabular/inference/categorify.cc b/cpp/nvtabular/inference/categorify.cc index 90c3c33622..9ec8a285ba 100644 --- a/cpp/nvtabular/inference/categorify.cc +++ b/cpp/nvtabular/inference/categorify.cc @@ -93,6 +93,9 @@ namespace nvtabular case 'u': switch (dtype.itemsize()) { + case 1: + insert_int_mapping(values); + return; case 2: insert_int_mapping(values); return; @@ -107,6 +110,9 @@ namespace nvtabular case 'i': switch (dtype.itemsize()) { + case 1: + insert_int_mapping(values); + return; case 2: insert_int_mapping(values); return; @@ -204,6 +210,8 @@ namespace nvtabular case 'u': switch (itemsize) { + case 1: + return transform_int(input); case 2: return transform_int(input); case 4: @@ -215,6 +223,8 @@ namespace nvtabular case 'i': switch (itemsize) { + case 1: + return transform_int(input); case 2: return transform_int(input); case 4: diff --git a/tests/unit/ops/test_categorify.py b/tests/unit/ops/test_categorify.py index 2c90488b2b..80a52c06a3 100644 --- a/tests/unit/ops/test_categorify.py +++ b/tests/unit/ops/test_categorify.py @@ -704,9 +704,11 @@ def test_categorify_inference(): "unicode_string": np.random.randint( low=a_char, high=z_char, size=num_rows * 10, dtype="int32" ).view("U10"), + "int8_feature": np.random.randint(0, 10, dtype="int8", size=num_rows), "int16_feature": np.random.randint(0, 10, dtype="int16", size=num_rows), "int32_feature": np.random.randint(0, 10, dtype="int32", size=num_rows), "int64_feature": np.random.randint(0, 10, dtype="int64", size=num_rows), + "uint8_feature": np.random.randint(0, 10, dtype="uint8", size=num_rows), "uint16_feature": np.random.randint(0, 10, dtype="uint16", size=num_rows), "uint32_feature": np.random.randint(0, 10, dtype="uint32", size=num_rows), "uint64_feature": np.random.randint(0, 10, dtype="uint64", size=num_rows),