Skip to content

Commit

Permalink
fix: type error in embedding_bag (#2418)
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 authored Oct 30, 2023
1 parent 3e612c1 commit 27a9f6d
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def embedding_bag(

# TODO: support 2D inputs
# indices = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_indices", indices, (-1,))

reduce_name = ""
if mode == 0: # sum
reduce_op = functools.partial(
impl.reduce.sum, ctx=ctx, target=target, source_ir=source_ir
Expand Down Expand Up @@ -143,7 +143,6 @@ def embedding_bag(
# however, pytorch doc says if `include_last_offset` is True, the size of offsets
# is equal to the number of bags + 1. The last element is the size of the input,
# or the ending index position of the last bag (sequence).

offsets[-1] = indices.shape[0]

# separately reduce embeddings for different bags
Expand All @@ -158,8 +157,8 @@ def embedding_bag(
f"{name}_slice_embed_{i}",
embed,
0,
offsets[i],
offsets[i + 1],
int(offsets[i]),
int(offsets[i + 1]),
1,
)
reduced_sliced_embed = reduce_op(
Expand Down

0 comments on commit 27a9f6d

Please sign in to comment.