[Fixbug] Resolve the min/max function according to compute capability (#112)

* resolve the min/max function according to compute capability

yaoyaoding committed Feb 18, 2023
1 parent 70ab992 commit 4ab4cea
Showing 5 changed files with 59 additions and 306 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -11,7 +11,7 @@ jobs:
   tests:
     if: github.repository == 'hidet-org/hidet'
     concurrency:
-      group: ${{ github.workflow }}-${{ github.ref }}
+      group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
     runs-on: [self-hosted, Linux, X64, gpu]
     container:
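Why this concurrency change helps: github.ref is the same for every push to a given branch, so grouping on it lets unrelated pushes to a shared branch cancel each other's CI runs. github.head_ref is only populated for pull_request events, and github.run_id is unique per run, so github.head_ref || github.run_id cancels superseded runs within a single PR while giving every non-PR run its own group. (This rationale is the well-known GitHub Actions pattern for cancel-in-progress, not something stated in the commit itself.)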
8 changes: 5 additions & 3 deletions python/hidet/cuda/device.py
@@ -29,7 +29,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         set_device(self.prev_device_id)
 
 
-@lru_cache(maxsize=1)
+@lru_cache(maxsize=None)
 def available() -> bool:
     """
     Returns True if CUDA is available, False otherwise.
@@ -42,7 +42,7 @@ def available() -> bool:
     return device_count() > 0
 
 
-@lru_cache(maxsize=1)
+@lru_cache(maxsize=None)
 def device_count() -> int:
     """
     Get the number of available CUDA devices.
@@ -122,14 +122,15 @@ def device(device_id: int):
     return CudaDeviceContext(device_id)
 
 
+@lru_cache(maxsize=None)
 def compute_capability(device_id: int = 0) -> Tuple[int, int]:
     """
     Get the compute capability of a CUDA device.
 
     Parameters
     ----------
     device_id: int
-        The ID of the device.
+        The ID of the device to query.
 
     Returns
     -------
@@ -178,3 +179,4 @@ def profiler_stop():
 if available():
     for i in range(device_count()):
         properties(i)
+        compute_capability(i)
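Taken together, the device.py changes memoize available(), device_count(), and the new compute_capability() with an unbounded lru_cache, and the warm-up loop at the bottom of the module now also primes the capability cache for every device at import time. A minimal usage sketch, assuming a CUDA-capable machine and that available and device_count are re-exported from hidet.cuda alongside compute_capability (the float16.py diff below imports compute_capability from there):

from hidet.cuda import available, device_count, compute_capability

if available():
    for i in range(device_count()):
        # The first call issues a CUDA driver query; repeated calls with the
        # same device_id are served from the lru_cache added in this commit.
        major, minor = compute_capability(i)
        print('device {}: sm_{}{}'.format(i, major, minor))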
63 changes: 53 additions & 10 deletions python/hidet/ir/primitives/cuda/math/float16.py
@@ -9,19 +9,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Callable
 from hidet.ir.expr import Expr, Call, ExprInt64, ExprFloat16, ExprInt16
-from hidet.ir.type import FuncType
+from hidet.ir.type import FuncType, DataType
 from hidet.ir.func import Function
-from hidet.ir.dtypes import int16
+from hidet.ir.dtypes import int16, float16, float32, int64
 from hidet.ir.primitives.func import register_primitive_function, primitive_func_pool, call_primitive_func
 from hidet.ir.primitives.math import MathFunctionSet, register_math_function_set
 from hidet.utils import initialize
 
 
 @initialize()
 def register_float16_primitives():
-    from hidet.ir.dtypes import int64, float16
-
     register_primitive_function('cuda_i64_to_f16', FuncType([int64], float16), codegen_name='__ll2half_rn')
@@ -71,8 +70,8 @@ def register(self):
             'round': ['hrint', 1],
             'ceil': ['hceil', 1],
             'floor': ['hfloor', 1],
-            'min': ['__hmin', 2],
-            'max': ['__hmax', 2],
+            'min_sm80': ['__hmin', 2],
+            'max_sm80': ['__hmax', 2],
             'fma': ['__hfma', 3],
         }
@@ -83,6 +82,8 @@ def register(self):
                 func_or_type=FuncType(param_types=['float16'] * num_args, ret_type='float16'),
             )
 
+        self.register_via_delegate('min', float16, float32, min, 2)
+        self.register_via_delegate('max', float16, float32, max, 2)
         self._register_tanh()
 
     def _register_tanh(self):
@@ -102,6 +103,40 @@ def cuda_f16_tanh(x: f16) -> f16:
 
         register_primitive_function(name='cuda_f16_tanh', func_or_type=cuda_f16_tanh)
 
+    def register_via_delegate(
+        self, name: str, target_type: DataType, delegate_type: DataType, delegate: Callable, num_args: int
+    ):
+        from hidet.lang import script, cast, attr
+
+        if num_args == 1:
+
+            @script
+            def delegated_primitive(v: target_type) -> target_type:
+                attr.func_name = 'cuda_f16_{}'.format(name)
+                return cast(delegate(cast(v, delegate_type)), target_type)
+
+        elif num_args == 2:
+
+            @script
+            def delegated_primitive(a: target_type, b: target_type) -> target_type:
+                attr.func_name = 'cuda_f16_{}'.format(name)
+                return cast(delegate(cast(a, delegate_type), cast(b, delegate_type)), target_type)
+
+        elif num_args == 3:
+
+            @script
+            def delegated_primitive(a: target_type, b: target_type, c: target_type) -> target_type:
+                attr.func_name = 'cuda_f16_{}'.format(name)
+                return cast(
+                    delegate(cast(a, delegate_type), cast(b, delegate_type), cast(c, delegate_type)), target_type
+                )
+
+        else:
+            raise ValueError('Unsupported num_args: {}'.format(num_args))
+
+        assert isinstance(delegated_primitive, Function)
+        register_primitive_function(name='cuda_f16_{}'.format(name), func_or_type=delegated_primitive)
+
     def call(self, name: str, *args) -> Expr:
         entry = primitive_func_pool.lookup_by_name(name)
         return Call(entry.var, args)
@@ -121,7 +156,6 @@ def exp(self, a: Expr) -> Expr:
     def erf(self, a: Expr) -> Expr:
         # use float32 erf to delegate the float16 erf
         from hidet.ir.expr import cast
-        from hidet.ir.dtypes import float32, float16
         from hidet.ir.primitives.math import erf
 
         return cast(erf(cast(a, float32)), float16)
@@ -145,15 +179,24 @@ def floor(self, a: Expr) -> Expr:
         return self.call('cuda_f16_floor', a)
 
     def min(self, a: Expr, b: Expr) -> Expr:
-        return self.call('cuda_f16_min', a, b)
+        from hidet.cuda import compute_capability
+
+        if compute_capability() >= (8, 0):
+            return self.call('cuda_f16_min_sm80', a, b)
+        else:
+            return self.call('cuda_f16_min', a, b)
 
     def max(self, a: Expr, b: Expr) -> Expr:
-        return self.call('cuda_f16_max', a, b)
+        from hidet.cuda import compute_capability
+
+        if compute_capability() >= (8, 0):
+            return self.call('cuda_f16_max_sm80', a, b)
+        else:
+            return self.call('cuda_f16_max', a, b)
 
     def pow(self, a: Expr, b: Expr) -> Expr:
         # use float32 pow to delegate the float16 pow
         from hidet.ir.expr import cast
-        from hidet.ir.dtypes import float32, float16
         from hidet.ir.primitives.math import pow
 
         a = cast(a, float32)
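The heart of the fix is the min/max dispatch above: the __hmin and __hmax half-precision intrinsics only exist on devices of compute capability 8.0 (Ampere) and newer, so the commit registers them under cuda_f16_min_sm80/cuda_f16_max_sm80 and keeps float32-delegating fallbacks under the plain names. The sketch below is a hypothetical, self-contained restatement of that selection rule; resolve_f16_call is illustrative only and not part of hidet:

def resolve_f16_call(name: str, cc: tuple) -> str:
    # sm_80 introduced native half-precision min/max intrinsics; older
    # architectures fall back to the delegate that round-trips through float32.
    if cc >= (8, 0):
        return 'cuda_f16_{}_sm80'.format(name)  # lowers to __hmin / __hmax
    return 'cuda_f16_{}'.format(name)  # registered via register_via_delegate


assert resolve_f16_call('min', (8, 6)) == 'cuda_f16_min_sm80'  # e.g. an RTX 3090
assert resolve_f16_call('max', (7, 5)) == 'cuda_f16_max'       # e.g. a Tesla T4

Note that the real dispatch calls compute_capability() with its default device_id of 0, so on a mixed-architecture multi-GPU machine the choice is resolved against the default device.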
76 changes: 0 additions & 76 deletions python/hidet/ir/primitives/cuda/math/float32.py
@@ -169,82 +169,6 @@ def atan2(self, a: Expr, b: Expr) -> Expr:
     def fma(self, a: Expr, b: Expr, c: Expr) -> Expr:
         return self.call('fma', a, b, c)
 
-    # # pylint: disable=abstract-method
-    # def register(self):
-    #     entries = {
-    #         'sin': ['sinf', 1],
-    #         'cos': ['cosf', 1],
-    #         'tanh': ['tanhf', 1],
-    #         'exp': ['expf', 1],
-    #         'erf': ['erff', 1],
-    #         'sqrt': ['sqrtf', 1],
-    #         'rsqrt': ['rsqrtf', 1],
-    #         'log': ['logf', 1],
-    #         'round': ['roundf', 1],
-    #         'ceil': ['ceilf', 1],
-    #         'floor': ['floorf', 1],
-    #         'min': ['fminf', 2],
-    #         'max': ['fmaxf', 2],
-    #         'pow': ['powf', 2],
-    #         'fma': ['fmaf', 3],
-    #     }
-    #
-    #     for name, (codegen_name, num_args) in entries.items():
-    #         register_primitive_function(
-    #             name='cuda_f32_{}'.format(name),
-    #             codegen_name=codegen_name,
-    #             func_or_type=FuncType(param_types=['float32'] * num_args, ret_type='float32'),
-    #         )
-    #
-    # def call(self, name: str, *args) -> Expr:
-    #     entry = primitive_func_pool.lookup_by_name(name)
-    #     return Call(entry.var, args)
-    #
-    # def sin(self, a: Expr) -> Expr:
-    #     return self.call('cuda_f32_sin', a)
-    #
-    # def cos(self, a: Expr) -> Expr:
-    #     return self.call('cuda_f32_cos', a)
-    #
-    # def tanh(self, a: Expr) -> Expr:
-    #     return self.call('cuda_f32_tanh', a)
-    #
-    # def exp(self, a: Expr) -> Expr:
-    #     return self.call('cuda_f32_exp', a)
-    #
-    # def erf(self, a: Expr) -> Expr:
-    #     return self.call('cuda_f32_erf', a)
-    #
-    # def sqrt(self, a: Expr) -> Expr:
-    #     return self.call('cuda_f32_sqrt', a)
-    #
-    # def rsqrt(self, a: Expr) -> Expr:
-    #     return self.call('cuda_f32_rsqrt', a)
-    #
-    # def log(self, a: Expr) -> Expr:
-    #     return self.call('cuda_f32_log', a)
-    #
-    # def round(self, a: Expr) -> Expr:
-    #     return self.call('cuda_f32_round', a)
-    #
-    # def ceil(self, a: Expr) -> Expr:
-    #     return self.call('cuda_f32_ceil', a)
-    #
-    # def floor(self, a: Expr) -> Expr:
-    #     return self.call('cuda_f32_floor', a)
-    #
-    # def min(self, a: Expr, b: Expr) -> Expr:
-    #     return self.call('cuda_f32_min', a, b)
-    #
-    # def max(self, a: Expr, b: Expr) -> Expr:
-    #     return self.call('cuda_f32_max', a, b)
-    #
-    # def pow(self, a: Expr, b: Expr) -> Expr:
-    #     return self.call('cuda_f32_pow', a, b)
-    #
-    # def fma(self, a: Expr, b: Expr, c: Expr) -> Expr:
-    #     return self.call('cuda_f32_fma', a, b, c)
-
 
 cuda_f32_math_function_set = CUDAFloat32MathFunctionSet()
 cuda_f32_math_function_set.register()
(The diff of the fifth changed file did not load and is omitted.)
