Skip to content

Commit

Permalink
Add support for int8 values to Categorify inference (#1818)
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverholworthy committed May 11, 2023
1 parent feaa418 commit 55fe934
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
10 changes: 10 additions & 0 deletions cpp/nvtabular/inference/categorify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ namespace nvtabular
case 'u':
switch (dtype.itemsize())
{
case 1:
insert_int_mapping<uint8_t>(values);
return;
case 2:
insert_int_mapping<uint16_t>(values);
return;
Expand All @@ -107,6 +110,9 @@ namespace nvtabular
case 'i':
switch (dtype.itemsize())
{
case 1:
insert_int_mapping<int8_t>(values);
return;
case 2:
insert_int_mapping<int16_t>(values);
return;
Expand Down Expand Up @@ -204,6 +210,8 @@ namespace nvtabular
case 'u':
switch (itemsize)
{
case 1:
return transform_int<uint8_t>(input);
case 2:
return transform_int<uint16_t>(input);
case 4:
Expand All @@ -215,6 +223,8 @@ namespace nvtabular
case 'i':
switch (itemsize)
{
case 1:
return transform_int<int8_t>(input);
case 2:
return transform_int<int16_t>(input);
case 4:
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/ops/test_categorify.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,9 +704,11 @@ def test_categorify_inference():
"unicode_string": np.random.randint(
low=a_char, high=z_char, size=num_rows * 10, dtype="int32"
).view("U10"),
"int8_feature": np.random.randint(0, 10, dtype="int8", size=num_rows),
"int16_feature": np.random.randint(0, 10, dtype="int16", size=num_rows),
"int32_feature": np.random.randint(0, 10, dtype="int32", size=num_rows),
"int64_feature": np.random.randint(0, 10, dtype="int64", size=num_rows),
"uint8_feature": np.random.randint(0, 10, dtype="uint8", size=num_rows),
"uint16_feature": np.random.randint(0, 10, dtype="uint16", size=num_rows),
"uint32_feature": np.random.randint(0, 10, dtype="uint32", size=num_rows),
"uint64_feature": np.random.randint(0, 10, dtype="uint64", size=num_rows),
Expand Down

0 comments on commit 55fe934

Please sign in to comment.