diff --git a/python/hidet/graph/frontend/torch/interpreter.py b/python/hidet/graph/frontend/torch/interpreter.py index fe79723a2..53f369efd 100644 --- a/python/hidet/graph/frontend/torch/interpreter.py +++ b/python/hidet/graph/frontend/torch/interpreter.py @@ -127,8 +127,17 @@ def _lookup_hidet_method(self, torch_method) -> Callable: def _lookup_hidet_function(self, torch_func) -> Optional[OverloadedFunction]: if torch_func not in Registry.registered_functions: name = self._get_callable_name(torch_func) + from hidet.graph.ops import cast + pattern2func = { - '_dynamo_get_item_lambda': OverloadedFunction.from_lambda(lambda target, index: target[index]) + '_dynamo_get_item_lambda': OverloadedFunction.from_lambda(lambda target, index: target[index]), + # Turns out the wrapped ops in issue #358 are some `numpy_method_wrapper` and `numpy_operator_wrapper`. + # According to the class definition in pytorch/torch/_dynamo/utils.py(line 2461-2497), + # they're just functionally equivalent to the original numpy functions. + '>': OverloadedFunction.from_lambda(lambda x, y: x >= y), + '>': OverloadedFunction.from_lambda( + lambda x, dtype: cast(x, data_type(dtype)) + ), } for pattern, func in pattern2func.items(): if pattern in name: