Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

Permalink
fix acc topk's handling of the case when dim=0, fix tests as well (py…
Browse files Browse the repository at this point in the history
…torch#64727)

Summary:
Pull Request resolved: pytorch#64727

the acc ops convertor for topk has a subtle bug (i found this while trying to introduce max/min)
the code does not differentiate between dim == None and dim ==0, but these are both different computations

Reviewed By: jfix71, 842974287

Differential Revision: D30833621

fbshipit-source-id: 6cd84e6ca4e95bb1a6d6465e61830b76808a9c78
  • Loading branch information
emad authored and facebook-github-bot committed Sep 9, 2021
1 parent 3d3ff4a commit 46c886e
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ def acc_ops_topk(network, target, args, kwargs, name):

num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
k = kwargs["k"]
dim = (kwargs["dim"] if kwargs["dim"] else -1) % num_dims
dim = (kwargs["dim"] if kwargs["dim"] is not None else -1) % num_dims
operation = trt.TopKOperation.MAX if kwargs["largest"] else trt.TopKOperation.MIN
layer = network.add_topk(
input_val, operation, k, get_axes_for_reduce_op(dim, network.has_implicit_batch_dimension)
Expand Down

0 comments on commit 46c886e

Please sign in to comment.