Skip to content

Commit

Permalink
cpp experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed May 30, 2023
1 parent aa98b37 commit 2df007c
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions cpp/nvtabular/inference/categorify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace nvtabular

if ((dtype.kind() == 'O') || (dtype.kind() == 'U'))
{
int64_t i = 0;
int64_t i = UNIQUE_OFFSET;
for (auto &value : values)
{
if (!py::cast<bool>(isnull(value)))
Expand Down Expand Up @@ -138,20 +138,29 @@ namespace nvtabular
size_t size = values.size();
for (size_t i = 0; i < size; ++i)
{
mapping_int[static_cast<int64_t>(data[i])] = i;
mapping_int[static_cast<int64_t>(data[i])] = i + UNIQUE_OFFSET;
}
}

// Encodes every element of `input` as an int64 category index.
//
// A value found in `mapping_int` is replaced by its stored index. A value
// that is not found is encoded as NULL_INDEX when pandas.isnull() reports it
// as null, and as OOV_INDEX otherwise.
//
// @param input  1-D numeric array of raw category values.
// @return       int64 array with input.size() encoded indices.
template <typename T>
py::array transform_int(py::array_t<T> input) const
{
  // Resolve pandas.isnull once per call; it is only invoked on lookup misses.
  py::object pandas = py::module_::import("pandas");
  py::object isnull = pandas.attr("isnull");
  py::array_t<int64_t> output(input.size());
  const T *input_data = input.data();
  int64_t *output_data = output.mutable_data();
  for (int64_t i = 0; i < input.size(); ++i)
  {
    auto it = mapping_int.find(static_cast<int64_t>(input_data[i]));
    if (it == mapping_int.end())
    {
      // Unknown key: distinguish a null input from a genuinely unseen value.
      output_data[i] = py::cast<bool>(isnull(input_data[i])) ? NULL_INDEX : OOV_INDEX;
    }
    else
    {
      output_data[i] = it->second;
    }
  }
  return output;
}
Expand All @@ -169,18 +178,18 @@ namespace nvtabular
{
if (value.is_none())
{
data[i] = 0;
data[i] = NULL_INDEX;
}
else if (PyUnicode_Check(value.ptr()) || PyBytes_Check(value.ptr()))
{
std::string key = py::cast<std::string>(value);
auto it = mapping_str.find(key);
data[i] = it == mapping_str.end() ? 0 : it->second;
data[i] = it == mapping_str.end() ? OOV_INDEX : it->second;
}
else if (PyBool_Check(value.ptr()))
{
auto it = mapping_int.find(value.ptr() == Py_True);
data[i] = it == mapping_int.end() ? 0 : it->second;
data[i] = it == mapping_int.end() ? OOV_INDEX : it->second;
}
else
{
Expand Down Expand Up @@ -247,6 +256,11 @@ namespace nvtabular

// Lookup table from string/bytes category values to their encoded index.
std::unordered_map<std::string, int64_t> mapping_str;
// Lookup table from integer (and boolean) category values to their encoded index.
std::unordered_map<int64_t, int64_t> mapping_int;

// Reserved output indices. Index 0 is not assigned by the code visible here;
// presumably it is reserved (e.g. for padding) -- TODO confirm.
// TODO: Handle multiple OOV buckets?
// Index emitted for null/None inputs.
const int64_t NULL_INDEX = 1;
// Index emitted for non-null inputs that are missing from the mapping tables.
const int64_t OOV_INDEX = 2;
// First index handed out to known categories; mapped values are position + UNIQUE_OFFSET.
const int64_t UNIQUE_OFFSET = 3;
};

// Reads in a parquet category mapping file in cpu memory using pandas
Expand Down

0 comments on commit 2df007c

Please sign in to comment.