From ea32c5cf0e95a847b9385f5558404f6a7026cbad Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Thu, 28 Sep 2023 00:26:33 -0400 Subject: [PATCH] [Docs] Update the documentation for the coming release (#360) --- docs/requirements.txt | 2 +- docs/source/_static/custom.css | 58 +++-- docs/source/conf.py | 4 +- .../developer-guides/hidet-script/index.rst | 8 - .../getting-started/build-from-source.rst | 2 +- docs/source/hidet-script/examples/index.rst | 18 ++ docs/source/hidet-script/index.rst | 11 + .../hidet-script/reference/1-type-system.rst | 198 ++++++++++++++++++ .../hidet-script/reference/2-expression.rst | 83 ++++++++ .../hidet-script/reference/3-statement.rst | 83 ++++++++ .../hidet-script/reference/4-function.rst | 17 ++ .../hidet-script/reference/5-module.rst | 32 +++ .../reference/6-cuda-specific.rst | 55 +++++ .../hidet-script/reference/7-cpu-specific.rst | 21 ++ docs/source/hidet-script/reference/index.rst | 22 ++ .../how-to-guides/add-new-operator/index.rst | 6 +- docs/source/index.rst | 17 +- docs/source/python_api/index.rst | 2 - docs/source/python_api/ir/compute.rst | 4 +- .../add-new-operator-compute-definition.py | 8 +- .../add-new-operator-rule-based.py | 4 +- .../add-new-operator-template-based.py | 41 +--- .../add-operator-resolve-rule.py | 0 .../add-subgraph-rewrite-rule.py | 17 +- .../hidet-script-dynamic-kernel.py | 27 +-- gallery/getting-started/quick-start.py | 38 +++- gallery/hidet-script/0-hello-world.py | 67 ++++++ gallery/hidet-script/1-scalar-addition.py | 42 ++++ gallery/hidet-script/2-vector-addition.py | 43 ++++ gallery/hidet-script/3-kernel-functions.py | 117 +++++++++++ gallery/hidet-script/4-naive-matmul.py | 66 ++++++ gallery/hidet-script/5-efficient-matmul.py | 164 +++++++++++++++ gallery/hidet-script/README.rst | 2 + gallery/how-to-guides/visualize-flow-graph.py | 4 +- ...n-onnx-model.py => optimize-onnx-model.py} | 21 +- gallery/tutorials/optimize-pytorch-model.py | 6 +- python/hidet/graph/transforms/base.py | 1 + scripts/lint/format.sh | 2 +- 38 files changed, 1169 insertions(+), 144 deletions(-) delete mode 100644 docs/source/developer-guides/hidet-script/index.rst create mode 100644 docs/source/hidet-script/examples/index.rst create mode 100644 docs/source/hidet-script/index.rst create mode 100644 docs/source/hidet-script/reference/1-type-system.rst create mode 100644 docs/source/hidet-script/reference/2-expression.rst create mode 100644 docs/source/hidet-script/reference/3-statement.rst create mode 100644 docs/source/hidet-script/reference/4-function.rst create mode 100644 docs/source/hidet-script/reference/5-module.rst create mode 100644 docs/source/hidet-script/reference/6-cuda-specific.rst create mode 100644 docs/source/hidet-script/reference/7-cpu-specific.rst create mode 100644 docs/source/hidet-script/reference/index.rst rename gallery/{how-to-guides => developer-guides}/add-new-operator-compute-definition.py (98%) rename gallery/{how-to-guides => developer-guides}/add-new-operator-rule-based.py (97%) rename gallery/{how-to-guides => developer-guides}/add-new-operator-template-based.py (88%) rename gallery/{how-to-guides => developer-guides}/add-operator-resolve-rule.py (100%) rename gallery/{how-to-guides => developer-guides}/add-subgraph-rewrite-rule.py (95%) create mode 100644 gallery/hidet-script/0-hello-world.py create mode 100644 gallery/hidet-script/1-scalar-addition.py create mode 100644 gallery/hidet-script/2-vector-addition.py create mode 100644 gallery/hidet-script/3-kernel-functions.py create mode 100644 
gallery/hidet-script/4-naive-matmul.py create mode 100644 gallery/hidet-script/5-efficient-matmul.py create mode 100644 gallery/hidet-script/README.rst rename gallery/tutorials/{run-onnx-model.py => optimize-onnx-model.py} (92%) diff --git a/docs/requirements.txt b/docs/requirements.txt index 54fb23bde..33654f268 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,7 +3,7 @@ sphinx sphinx-gallery sphinx-copybutton autodocsumm -sphinx-book-theme +sphinx-book-theme==1.0.1 matplotlib sphinxcontrib-bibtex sphinxcontrib-googleanalytics diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css index 08c1946c6..37f631bcc 100644 --- a/docs/source/_static/custom.css +++ b/docs/source/_static/custom.css @@ -17,13 +17,13 @@ div.sphx-glr-download-python { /* font-weight: 550;*/ /*}*/ -div.container-xl { - max-width: 1500px; -} +/*div.container-xl {*/ +/* max-width: 1500px;*/ +/*}*/ -.sphx-glr-script-out .highlight pre { - background-color: #e9ecef !important; -} +/*.sphx-glr-script-out .highlight pre {*/ +/* background-color: #e9ecef !important;*/ +/*}*/ div.sphx-glr-download a { background-color: #0084c845 !important; @@ -35,33 +35,49 @@ div.sphx-glr-download a:hover { background-image: none !important; } -dt.sig { - letter-spacing: 0; - font-family: Menlo,Monaco,Consolas,SFMono-Regular,"Liberation Mono","Courier New",monospace; - /*font-family: ;*/ - /*font-size=1em;*/ -} +/*dt.sig {*/ +/* letter-spacing: 0;*/ +/* font-family: Menlo,Monaco,Consolas,SFMono-Regular,"Liberation Mono","Courier New",monospace;*/ +/* !*font-family: ;*!*/ +/* !*font-size=1em;*!*/ +/*}*/ -span.sig-name { - color: #0558b7; -} +/*span.sig-name {*/ +/* color: #0558b7;*/ +/*}*/ em.sig-param { font-style: normal; } em.property { font-style: normal; } -code { - font-family: monospace; +a { + text-decoration: unset; } -code { - color: #434552; + +.navbar-brand img { + width: 210px; } -.heading-style, h1, h2, h3, h4, h5, h6 { - font-family: Ubuntu, system-ui; +.navbar-brand { + padding: 1.8rem 0; } +/*code {*/ +/* font-family: monospace;*/ +/*}*/ +/*code {*/ +/* color: #434552;*/ +/*}*/ + +/*.heading-style, h1, h2, h3, h4, h5, h6 {*/ +/* font-family: Ubuntu, system-ui;*/ +/*}*/ dl.class, dl.function { margin-bottom: 3em; +} + +.bd-sidebar-primary .sidebar-primary-items__end { + margin-bottom: 0; + margin-top: 0 } \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index c913516d1..978d9eb54 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -114,8 +114,8 @@ html_theme_options = { "repository_url": "https://github.com/hidet-org/hidet", "use_repository_button": True, - 'logo_only': True, - "extra_navbar": r"Customized Netron", + # 'logo_only': True, + # "extra_navbar": r"Customized Netron", "show_navbar_depth": 1, # "home_page_in_toc": True } diff --git a/docs/source/developer-guides/hidet-script/index.rst b/docs/source/developer-guides/hidet-script/index.rst deleted file mode 100644 index f997da96b..000000000 --- a/docs/source/developer-guides/hidet-script/index.rst +++ /dev/null @@ -1,8 +0,0 @@ -Hidet Script -============ - - -.. 
toctree:: - - /gallery/developer-guides/hidet-script-dynamic-kernel - diff --git a/docs/source/getting-started/build-from-source.rst b/docs/source/getting-started/build-from-source.rst index 8f5a6529c..4b35652d8 100644 --- a/docs/source/getting-started/build-from-source.rst +++ b/docs/source/getting-started/build-from-source.rst @@ -33,7 +33,7 @@ shared library: After building, you could find two libraries ``libhidet.so`` and ``libhidet_runtime.so`` under ``build/lib`` directory. Install the Hidet Python package -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Next we will install the Python package of Hidet in the develop mode via pip: diff --git a/docs/source/hidet-script/examples/index.rst b/docs/source/hidet-script/examples/index.rst new file mode 100644 index 000000000..8ef73cd85 --- /dev/null +++ b/docs/source/hidet-script/examples/index.rst @@ -0,0 +1,18 @@ +Examples +======== + +This section contains a collection of examples that demonstrate how to use Hidet Script to write kernel programs. Each +example is a self-contained hidet script program that can be run directly. + +.. _hidet script examples: + +.. toctree:: + :maxdepth: 1 + :caption: Hidet Script Examples + + ../../gallery/hidet-script/0-hello-world + ../../gallery/hidet-script/1-scalar-addition + ../../gallery/hidet-script/2-vector-addition + ../../gallery/hidet-script/3-kernel-functions + ../../gallery/hidet-script/4-naive-matmul + ../../gallery/hidet-script/5-efficient-matmul diff --git a/docs/source/hidet-script/index.rst b/docs/source/hidet-script/index.rst new file mode 100644 index 000000000..8bea5bce0 --- /dev/null +++ b/docs/source/hidet-script/index.rst @@ -0,0 +1,11 @@ +Introduction +============ + +Hidet Script is a domain-specific language (DSL) for writing tensor programs directly in python. +Users can write tensor programs with python's syntax, with some constraints and extensions. +A transpiler is used to translate the python abstract syntax tree (AST) to Hidet's tensor program IR. +Then, the translated tensor programs in Hidet IR are optimized and compiled to the target binary for execution. +The tensor program writer stays in the python environment throughout the whole process. + + +To get started, please refer to the :ref:`hidet script examples`. diff --git a/docs/source/hidet-script/reference/1-type-system.rst b/docs/source/hidet-script/reference/1-type-system.rst new file mode 100644 index 000000000..fb8af1f42 --- /dev/null +++ b/docs/source/hidet-script/reference/1-type-system.rst @@ -0,0 +1,198 @@ +Type System +=========== + +In hidet script, we have a type system that contains scalar types, tensor types, and pointer types. + +Scalar types +------------ + +Hidet supports the following scalar types: +- integer types: ``i8``, ``i16``, ``i32``, ``i64`` (``int8``, ``int16``, ``int32``, ``int64``) +- floating point types: ``f16``, ``f32``, ``f64``, ``bf16``, ``tf32`` (``float16``, ``float32``, ``float64``, ``bfloat16``, ``tfloat32``) +- boolean type: ``bool`` +- complex types: ``c64``, ``c128`` (``complex64``, ``complex128``) + +Some types have both short names and long names. For example, ``i8`` and ``int8`` are the same type. + +There are also vectorized scalar types: +- vectorized integer types: ``i8x4`` (``int8x4``) +- vectorized float types: ``f16x2``, ``f32x4`` (``float16x2``, ``float32x4``) + +Tensor type +----------- + +Hidet is designed to simplify tensor program writing. 
Therefore, we have a powerful tensor type that +represents a tensor with a specific element data type, shape, and memory layout. More specifically, a +tensor type has the following attributes: +- ``dtype``: the data type of the tensor elements, can be any scalar type. +- ``shape``: a list of expressions that represents the shape of the tensor. +- ``layout``: the memory layout of the tensor. + +The following code snippet shows how to define a tensor type: + +.. code-block:: + + import hidet + from hidet.lang import attrs, printf + from hidet.lang.types import tensor, f32 + + with hidet.script_module() as script_module: + @hidet.script + def kernel(): + attrs.func_kind = 'cpu_kernel' + + # by default, the layout is a row-major layout + a = tensor(dtype=f32, shape=[1024, 1024]) + + a[0, 0] = 0.0 + + printf("a[%d, %d] = %.1f\n", 0, 0, a[0, 0]) + + module = script_module.build() + module() + + +Tensor shape +~~~~~~~~~~~~ + +The shape of the tensor must be determined at the compile time. Therefore, the shape of the tensor can only +be defined with constant expressions. If we want to access a tensor with shape determined at runtime with +variable expressions, we can use *tensor pointer* (will be discussed later). + + +Tensor layout +~~~~~~~~~~~~~ + +The layout of a tensor defines how to map the coordinates of a tensor element to the linear position of the +element in the memory space. Generally speaking, a layout maps a :math:`n`-dimensional coordinate +:math:`(c_0, c_1, \dots, c_{n-1})` to a linear index: + +.. math:: + + index = layout(c_0, c_1, ..., c_{n-1}) + + +The most commonly used layout is the row-major layout. In row-major layout, the linear index is calculated as: + + +.. math:: + + index = c_0 \times s_1 \times s_2 \times \dots \times s_{n-1} + c_1 \times s_2 \times \dots \times s_{n-1} + \dots + c_{n-2} \times s_{n-1} + c_{n-1} + +where :math:`s_i` is the size of the :math:`i`-th dimension of the tensor: :math:`shape=(s_0, s_1, \dots, s_{n-1})`. + + +Similar to the row-major layout, we can also define a column-major layout as follows: + +.. math:: + + index = c_{n-1} \times s_{n-2} \times \dots \times s_1 \times s_0 + c_{n-2} \times \dots \times s_1 \times s_0 + \dots + c_1 \times s_0 + c_0 + +The row-major layout is the default layout if we do not specify the layout of a tensor. We can also specify +the layout of a tensor with the ``layout`` argument of the ``tensor`` type. For example, we can define a tensor with +column-major layout as follows: + +.. code-block:: + + from hidet.lang.layout import column_major + from hidet.lang.types import tensor, f32 + # ... + a = tensor(dtype=f32, shape=[1024, 1024], layout=column_major(1024, 1024)) + # or ignore shape if the layout is specified + b = tensor(dtype=f32, layout=column_major(1024, 1024)) + + +Both row-major layout and column-major layout are special cases of the strided layout. +In hidet, we can define a strided layout like + + +.. 
code-block:: + + from hidet.lang.layout import strided_layout + from hidet.lang.types import tensor, f32 + + # equivalent to row-major layout + a = tensor(dtype=f32, layout=strided_layout(shape=[1024, 1024], ranks=[0, 1])) + # equivalent to column-major layout + b = tensor(dtype=f32, layout=strided_layout(shape=[1024, 1024], ranks=[1, 0])) + # the ranks define the order of the dimensions from the one that changes the slowest to the one that changes the fastest + c = tensor(dtype=f32, layout=strided_layout(shape=[2, 2, 2], ranks=[0, 2, 1])) + # c[coordinate] -> index + # c[0, 0, 0] -> 0 + # c[0, 1, 0] -> 1 + # c[0, 0, 1] -> 2 + # c[0, 1, 1] -> 3 + # c[1, 0, 0] -> 4 + # c[1, 1, 0] -> 5 + # c[1, 0, 1] -> 6 + # c[1, 1, 1] -> 7 + +Given two layouts $f$ and $g$, we can define a new layout $h$ as the composition of $f$ and $g$ with $f$ as the outer +layout and $g$ as the inner layout: + +.. math:: + + h(\textbf{c}) = f(\textbf{c}/\textbf{s}_{g}) * n_g + g(\textbf{c} \mod \textbf{s}_{g}) + +where :math:`\textbf{c}` is the coordinate of the tensor element, :math:`\textbf{s}_{g}` is the shape of the inner +layout :math:`g`, and :math:`n_g` is the number of elements in the inner layout :math:`g`. The division and modulo +operations are performed element-wise. The composed layout $h$ has the same number of dimensions as the outer and inner +layouts, and the shape of the composed layout is the elementwise product of the shapes of the outer and inner layouts. + +In hidet script, we can use the *multiply* operator ``*`` to compose two layouts. For example, we can define a +composed layout as follows: + +.. code-block:: + + from hidet.lang.layout import row_major, column_major + + c = row_major(2, 1) * row_major(2, 2) + # c shape=[4, 2] + # c[0, 0] -> 0 + # c[0, 1] -> 1 + # c[1, 0] -> 2 + # c[1, 1] -> 3 + # c[2, 0] -> 4 + # c[2, 1] -> 5 + # c[3, 0] -> 6 + # c[3, 1] -> 7 + + d = row_major(2, 1) * column_major(2, 2) + # d shape=[4, 2] + # d[0, 0] -> 0 + # d[1, 0] -> 1 + # d[0, 1] -> 2 + # d[1, 1] -> 3 + # d[2, 0] -> 4 + # d[3, 0] -> 5 + # d[2, 1] -> 6 + # d[3, 1] -> 7 + +We can apply the composition operation multiple times to compose multiple layouts. For example, + +.. code-block:: + + from hidet.lang.layout import row_major, column_major + + e = row_major(2, 1) * row_major(2, 2) * column_major(2, 2) # e shape=[8, 4] + +The composition operation is associative, i.e., :math:`(f * g) * h = f * (g * h)`, but not commutative, i.e., +it is highly likely :math:`f * g \neq g * f`. + + +Pointer types +~~~~~~~~~~~~~ + +In hidet, we can define a pointer type with the same semantics as the pointer type in C/C++. + +To construct a pointer type, we use the ``~`` operator to apply to a scalar type or pointer type: + +- ``~i32``: a pointer to ``i32`` type +- ``~(~f16)``: a pointer to a pointer to ``f16`` type + + +Void type +~~~~~~~~~ + +The ``void`` type can be used as the return type of a function, or used to define a ``void`` pointer type +(i.e., ``~void``). diff --git a/docs/source/hidet-script/reference/2-expression.rst b/docs/source/hidet-script/reference/2-expression.rst new file mode 100644 index 000000000..76dffdba5 --- /dev/null +++ b/docs/source/hidet-script/reference/2-expression.rst @@ -0,0 +1,83 @@ +Expressions +=========== + +Hidet script supports the following expressions: + +Literals +-------- + +The literal expressions are the expressions that represent constant values. An integer literal (e.g., ``1``) +has data type ``i32`` by default. 
A floating point literal (e.g., ``1.0``) has data type ``float32`` by default. +A boolean literal (i.e., ``True`` and ``False``) has data type ``bool`` by default. To define a literal with a +specific data type, we can use the form ``dtype(value)``, like ``f16(1.0)``, in hidet script. + +Variables +--------- + +A variable is an expression that represents a memory location. A variable has a name and a data type. A +variable can be defined in 1) the function parameters, 2) the variable declaration statement, 3) the +for loop statement, 4) the for-mapping statement. + +Unary expressions +----------------- + +A unary expression is an expression that applies a unary operator to a single operand. The unary operators +supported in hidet script are: + +- ``+e``: unary plus +- ``-e``: unary minus +- ``~e``: get the address of ``e`` +- ``bitwise_not(e)``: bitwise not, where ``bitwise_not`` refers to ``hidet.lang.bitwise_not`` +- ``not cond``: logical not + +Binary expressions +------------------ + +A binary expression is an expression that applies a binary operator to two operands. The binary operators +supported in hidet script are: + +- ``e1 + e2``: addition +- ``e1 - e2``: subtraction +- ``e1 * e2``: multiplication +- ``e1 / e2``: division (we follow the semantics of c/c++ instead of python) +- ``e1 % e2``: remainder +- ``e1 ** e2``: power +- ``e1 << e2``: left shift +- ``e1 >> e2``: right shift +- ``e1 & e2``: bitwise and +- ``e1 | e2``: bitwise or +- ``e1 ^ e2``: bitwise xor +- ``e1 and e2``: logical and +- ``e1 or e2``: logical or +- ``e1 == e2``: equal +- ``e1 != e2``: not equal +- ``e1 < e2``: less than +- ``e1 <= e2``: less than or equal +- ``e1 > e2``: greater than +- ``e1 >= e2``: greater than or equal + + +Note on division: in python, the division operator ``/`` will produce a floating point result even if the +operands are integers. However, in hidet script, we follow the semantics of c/c++: if the operands are integers, +the division operator ``/`` will produce an integer result (truncated towards zero, as in c/c++); if the operands are floating +point numbers, the division operator ``/`` will produce a floating point result. + +Ternary expressions +------------------- + +A ternary expression is an expression that applies a ternary operator to three operands. The ternary operator +supported in hidet script is: + +- ``true_expr if cond else false_expr``: conditional expression + +This operator has the same semantics as the conditional expression in c/c++: ``cond ? true_expr : false_expr``. + +Subscript and slice expressions +------------------------------- + +For a tensor or tensor pointer ``e1``, the following expressions are supported: + +- ``e1[p1, p2, ..., pn]``: subscript expression +- ``e1[p1, p2:q2, p3:, :p4, :, p5]``: slice expression +- ``func(e1, e2, ..., en)``: function call expression +- ``address(e)``: get the address of ``e``, where ``address`` refers to ``hidet.lang.address`` diff --git a/docs/source/hidet-script/reference/3-statement.rst b/docs/source/hidet-script/reference/3-statement.rst new file mode 100644 index 000000000..926c2ce64 --- /dev/null +++ b/docs/source/hidet-script/reference/3-statement.rst @@ -0,0 +1,83 @@ +Statements +========== + +Control flow statements +----------------------- + +Hidet script supports the following control flow statements: + +- ``if`` statement +- ``for`` statement +- ``for-mapping`` statement +- ``while`` statement + +If statement +------------ + +The ``if`` statement has the same semantics as the ``if`` statement in c/c++ and python. 
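For example, the following minimal sketch (assuming the same imports and build steps as the earlier examples) branches on a function parameter inside a hidet script function:

.. code-block::

    import hidet
    from hidet.lang import attrs, printf
    from hidet.lang.types import i32

    with hidet.script_module() as script_module:

        @hidet.script
        def launch(n: i32):
            attrs.func_kind = 'public'

            # branch on the runtime value of the parameter n
            if n % 2 == 0:
                printf("%d is even\n", n)
            else:
                printf("%d is odd\n", n)


    module = script_module.build()
    module(3)  # prints "3 is odd"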
+ +For statement +------------- + +Hidet script supports the following kinds of for statements: + +.. code-block:: + + from hidet.lang import grid, printf + + for i in range(10): + printf("%d\n", i) + + for i, j in grid(10, 10): + printf("%d %d\n", i, j) + + for indices in grid(10, 10, bind_tuple=True): + printf("%d %d\n", indices[0], indices[1]) + + # explicitly set the attributes of the loop variables + # the attribute can be one of + # - 'p': parallelize this loop axis + # - 'u': unroll this loop axis + # a number can be appended to the attribute to specify how many threads to use or the unroll factor + # like 'p2u3' means parallelize loop axis 'i' with 2 threads and unroll loop axis 'j' with factor 3. + for i, j in grid(10, 10, attrs='pu'): + printf("%d %d\n", i, j) + + +For mapping statement +--------------------- + +Task mapping +~~~~~~~~~~~~ + +Please refer to the `Hidet paper <https://dl.acm.org/doi/10.1145/3575693.3575702>`_ for the definition of task mapping. + +.. todo:: + + add a brief introduction here to make it self-contained + +Iterate the task mapping +~~~~~~~~~~~~~~~~~~~~~~~~ + +The task mappings are defined in the ``hidet.lang.mapping`` module. To use the task mappings, we can import the module +and use the task mappings like the following: + +.. code-block:: + + from hidet.lang import printf, grid + from hidet.lang.mapping import spatial, repeat + + # iterate the spatial mapping + for w in grid(10, attrs='p'): + for i, j in spatial(2, 5).on(w): + printf("%d %d\n", i, j) + + for i, j in spatial(2, 5).repeat(3, 4).on(w): + # task mapping shape: (6, 20) + # num workers: 10 + printf("%d %d\n", i, j) + +While statement +--------------- + +Hidet also supports the ``while`` statement, and it has the same semantics as python and c/c++. diff --git a/docs/source/hidet-script/reference/4-function.rst b/docs/source/hidet-script/reference/4-function.rst new file mode 100644 index 000000000..72c4f3b95 --- /dev/null +++ b/docs/source/hidet-script/reference/4-function.rst @@ -0,0 +1,17 @@ +Function +======== + +Function kinds +-------------- + +A function can be one of the following kinds: + +- ``public``: a public function that can be invoked in python directly +- ``cuda_kernel``: a cuda kernel function +- ``cuda_internal``: a cuda device function that can only be invoked by cuda kernel/device functions +- ``cpu_kernel``: a cpu kernel function +- ``cpu_internal``: a cpu function that will be used by other cpu functions + +Only the ``public`` functions will be exposed to python. For a module that defines a kernel function +(i.e., ``cuda_kernel`` or ``cpu_kernel``) but does not define a ``public`` function named ``launch``, hidet +will automatically create a ``public`` function named ``launch`` that launches the kernel function. diff --git a/docs/source/hidet-script/reference/5-module.rst b/docs/source/hidet-script/reference/5-module.rst new file mode 100644 index 000000000..347e0e8e4 --- /dev/null +++ b/docs/source/hidet-script/reference/5-module.rst @@ -0,0 +1,32 @@ +Module +====== + +Script module +------------- + +A script module is a collection of hidet script functions and global variables. It serves as a compilation unit +of hidet. We can use ``hidet.script_module()`` to create a script module. The created script module can be used as +a python context manager like + +.. code-block:: + + import hidet + from hidet.lang import attrs + from hidet.lang.types import f32 + + with hidet.script_module() as script_module: + # define global variables like + script_module.define_global_var(name='global_var', var_type=f32) + ... 
+ + # define functions like + @hidet.script + def foo(): + attrs.func_kind = 'public' # the function kind is mandatory + ... + + # we can define multiple functions in the script module, and they can call each other + + # we can build the script module to get a CompiledModule (hidet.runtime.CompiledModule) + # that can be invoked in python directly + module = script_module.build() diff --git a/docs/source/hidet-script/reference/6-cuda-specific.rst b/docs/source/hidet-script/reference/6-cuda-specific.rst new file mode 100644 index 000000000..5be9182d7 --- /dev/null +++ b/docs/source/hidet-script/reference/6-cuda-specific.rst @@ -0,0 +1,55 @@ +CUDA Specifics +============== + +.. todo:: + + make it more comprehensive and detailed + + +Function attributes +------------------- + +The ``cuda_kernel`` function kind has the following attributes: +- ``attrs.cuda.block_dim`` (required): the block dimensions +- ``attrs.cuda.grid_dim`` (required): the grid dimensions +- ``attrs.cuda.dynamic_smem_bytes`` (optional): the dynamic shared memory size to use +- ``attrs.cuda.min_blocks`` (optional): the minimum number of blocks this kernel will be launched with. + +Memory scope +------------ + +To define a tensor that resides in the shared memory, we can specify the ``scope`` argument of +the ``hidet.lang.types.tensor`` constructor: + +.. code-block:: + + from hidet.lang.types import tensor, f32, DeclareScope + + # define a tensor in the shared memory + a = tensor(dtype=f32, shape=[10, 10], scope='shared') # use the string to specify the scope + b = tensor(dtype=f32, shape=[10, 10], scope=DeclareScope.Shared) # use the enum to specify the scope + + # similarly, we can define a tensor that resides in the register file + # please note that each thread will have a f32[10, 10] tensor + c = tensor(dtype=f32, shape=[10, 10], scope='register') + d = tensor(dtype=f32, shape=[10, 10], scope=DeclareScope.Register) + +Primitive functions +------------------- + +Hidet provides some primitive functions that can be used in the cuda kernel functions. The primitive functions +are defined in the ``hidet.lang.cuda`` module. The following list shows the commonly used primitive functions: + +.. todo:: + + make a full list in the reference section. + +- ``threadIdx``, ``blockIdx``, ``blockDim``, ``gridDim``: the thread index, block index, block dimension and grid dimension. +- ``syncthreads()``: synchronize all threads in the same block. +- ``ldmatrix(...)``: load a matrix from shared memory to the register file. +- ``mma_sync(...)``: perform matrix-matrix multiplication using the tensor cores. +- ``atomic_add(...)``: perform atomic add operation (other atomic functions like ``atomic_max`` are also included). +- ``shfl_sync(...)``: warp shuffle operation. +- ``dynamic_shared_memory(...)``: access the dynamically allocated shared memory. + +Please refer to the ``hidet.lang.cuda`` module for the complete list of supported primitive functions. diff --git a/docs/source/hidet-script/reference/7-cpu-specific.rst b/docs/source/hidet-script/reference/7-cpu-specific.rst new file mode 100644 index 000000000..d2b5de6a8 --- /dev/null +++ b/docs/source/hidet-script/reference/7-cpu-specific.rst @@ -0,0 +1,21 @@ +CPU Specifics +============= + +Primitive functions +------------------- + +Hidet provides primitives to use the avx instructions on modern cpus. 
They include: + +- ``avx_f32x4_load(...)``: vectorized load of 4 f32 values from memory +- ``avx_f32x4_store(...)``: vectorized store of 4 f32 values to memory +- ``avx_f32x4_fmadd(...)``: vectorized fused multiply-add operation +- ``avx_f32x4_setzero(...)``: get a zero-initialized vector +- ``avx_f32x4_broadcast(...)``: broadcast a scalar to a vector + +There are also corresponding ``f32x8`` primitives. + +Multi-threading +--------------- + +Hidet relies on OpenMP to support multi-threading. To use multi-threading, please specify the +``p`` attribute of the ``hidet.lang.grid`` or ``hidet.lang.mapping.repeat`` functions. diff --git a/docs/source/hidet-script/reference/index.rst b/docs/source/hidet-script/reference/index.rst new file mode 100644 index 000000000..6570dccbf --- /dev/null +++ b/docs/source/hidet-script/reference/index.rst @@ -0,0 +1,22 @@ +Reference +========= + +Like other programming languages, Hidet Script has its own type system, expressions, statements, functions, and modules. + +Each module is a compilation unit, and it contains a collection of functions and global variables. Each function +executes a series of statements. The function can define variables and manipulate them. Each variable has its data +type. + +The details of the type system, expressions, statements, functions, and modules are described in the following sections. + +.. toctree:: + :maxdepth: 1 + :caption: Hidet Script Reference + + 1-type-system + 2-expression + 3-statement + 4-function + 5-module + 6-cuda-specific + 7-cpu-specific diff --git a/docs/source/how-to-guides/add-new-operator/index.rst b/docs/source/how-to-guides/add-new-operator/index.rst index 203f8095a..b7d738ebb 100644 --- a/docs/source/how-to-guides/add-new-operator/index.rst +++ b/docs/source/how-to-guides/add-new-operator/index.rst @@ -15,11 +15,11 @@ an operator. :maxdepth: 1 :caption: Define Computation - ../../gallery/how-to-guides/add-new-operator-compute-definition + ../../gallery/developer-guides/add-new-operator-compute-definition .. toctree:: :maxdepth: 1 :caption: Two Scheduling Methods - ../../gallery/how-to-guides/add-new-operator-rule-based - ../../gallery/how-to-guides/add-new-operator-template-based + ../../gallery/developer-guides/add-new-operator-rule-based + ../../gallery/developer-guides/add-new-operator-template-based diff --git a/docs/source/index.rst b/docs/source/index.rst index d20363d71..35ab05e24 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,29 +21,32 @@ Hidet is an open-source DNN inference framework, it features :caption: Tutorials gallery/tutorials/optimize-pytorch-model - gallery/tutorials/run-onnx-model + gallery/tutorials/optimize-onnx-model .. toctree:: :maxdepth: 1 - :caption: How-to Guide + :caption: Hidet Script + + hidet-script/index + hidet-script/examples/index + hidet-script/reference/index - how-to-guides/add-new-operator/index - gallery/how-to-guides/add-operator-resolve-rule - gallery/how-to-guides/add-subgraph-rewrite-rule - gallery/how-to-guides/visualize-flow-graph .. toctree:: :maxdepth: 1 :caption: Developer Guide + how-to-guides/add-new-operator/index + gallery/developer-guides/add-operator-resolve-rule + gallery/developer-guides/add-subgraph-rewrite-rule developer-guides/contributing.rst - developer-guides/hidet-script/index .. toctree:: :maxdepth: 1 :caption: Notes notes/operator-cache + gallery/how-to-guides/visualize-flow-graph .. 
toctree:: :maxdepth: 1 diff --git a/docs/source/python_api/index.rst b/docs/source/python_api/index.rst index c0cb41128..dc6d1425d 100644 --- a/docs/source/python_api/index.rst +++ b/docs/source/python_api/index.rst @@ -12,12 +12,10 @@ Python API root option - driver cuda tensor data_types ops/index - ir/index graph/index runtime/index utils/index diff --git a/docs/source/python_api/ir/compute.rst b/docs/source/python_api/ir/compute.rst index efbecc88a..41781a363 100644 --- a/docs/source/python_api/ir/compute.rst +++ b/docs/source/python_api/ir/compute.rst @@ -3,8 +3,8 @@ hidet.ir.compute .. tip:: - Please refer to :doc:`here ` for how to use these compute - primitives to define a computation task. + Please refer to :doc:`here ` for how to use these + compute primitives to define a computation task. .. automodule:: hidet.ir.compute diff --git a/gallery/how-to-guides/add-new-operator-compute-definition.py b/gallery/developer-guides/add-new-operator-compute-definition.py similarity index 98% rename from gallery/how-to-guides/add-new-operator-compute-definition.py rename to gallery/developer-guides/add-new-operator-compute-definition.py index 80972050e..8c22093f8 100644 --- a/gallery/how-to-guides/add-new-operator-compute-definition.py +++ b/gallery/developer-guides/add-new-operator-compute-definition.py @@ -344,9 +344,7 @@ def reduce_sum_example(): b = compute( 'b', shape=[4], - fcompute=lambda i: reduce( - shape=[3], fcompute=lambda j: a[i, j], reduce_type='sum' - ), + fcompute=lambda i: reduce(shape=[3], fcompute=lambda j: a[i, j], reduce_type='sum'), ) task = Task('reduce_sum', inputs=[a], outputs=[b]) run_task(task, [hidet.randn([4, 3])]) @@ -365,9 +363,7 @@ def arg_max_example(): b = compute( 'b', shape=[4], - fcompute=lambda i: arg_reduce( - extent=3, fcompute=lambda j: a[i, j], reduce_type='max' - ), + fcompute=lambda i: arg_reduce(extent=3, fcompute=lambda j: a[i, j], reduce_type='max'), ) task = Task('arg_max', inputs=[a], outputs=[b]) run_task(task, [hidet.randn([4, 3])]) diff --git a/gallery/how-to-guides/add-new-operator-rule-based.py b/gallery/developer-guides/add-new-operator-rule-based.py similarity index 97% rename from gallery/how-to-guides/add-new-operator-rule-based.py rename to gallery/developer-guides/add-new-operator-rule-based.py index 3251cc371..391ea3290 100644 --- a/gallery/how-to-guides/add-new-operator-rule-based.py +++ b/gallery/developer-guides/add-new-operator-rule-based.py @@ -50,9 +50,7 @@ def __init__(self, a: TensorNode, b: TensorNode): name='c', shape=[batch_size, m_size, n_size], fcompute=lambda p, i, j: reduce( - shape=[k_size], - fcompute=lambda k: a[p, i, k] * b[p, k, j], - reduce_type='sum', + shape=[k_size], fcompute=lambda k: a[p, i, k] * b[p, k, j], reduce_type='sum' ), ) diff --git a/gallery/how-to-guides/add-new-operator-template-based.py b/gallery/developer-guides/add-new-operator-template-based.py similarity index 88% rename from gallery/how-to-guides/add-new-operator-template-based.py rename to gallery/developer-guides/add-new-operator-template-based.py index 21435832d..e257715ab 100644 --- a/gallery/how-to-guides/add-new-operator-template-based.py +++ b/gallery/developer-guides/add-new-operator-template-based.py @@ -28,9 +28,7 @@ def __init__(self, a: TensorNode, b: TensorNode): name='c', shape=[batch_size, m_size, n_size], fcompute=lambda p, i, j: reduce( - shape=[k_size], - fcompute=lambda k: a[p, i, k] * b[p, k, j], - reduce_type='sum', + shape=[k_size], fcompute=lambda k: a[p, i, k] * b[p, k, j], reduce_type='sum' ), ) super().__init__( 
@@ -108,37 +106,25 @@ def batch_matmul_mma_fp16_schedule(task: BatchMatmulFp16Task) -> IRModule: with hidet.script_module() as module: @hidet.script - def load_regs_a( - smem_a: f16[block_m, block_k], regs_a: f16[4, mma_config.a_elements] - ): + def load_regs_a(smem_a: f16[block_m, block_k], regs_a: f16[4, mma_config.a_elements]): """Load A registers from shared memory.""" warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32 - for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on( - warp_id - ): + for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(warp_id): for mi in range(mma_count_m): p = 0 for i, k in mma_config.a_load_map.on(lane_id): - regs_a[mi, p] = smem_a[ - wi * warp_m + mi * mma_m + i, wk * warp_k + k - ] + regs_a[mi, p] = smem_a[wi * warp_m + mi * mma_m + i, wk * warp_k + k] p += 1 @hidet.script - def load_regs_b( - smem_b: f16[block_k, block_n], regs_b: f16[8, mma_config.b_elements] - ): + def load_regs_b(smem_b: f16[block_k, block_n], regs_b: f16[8, mma_config.b_elements]): """Load B registers from shared memory.""" warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32 - for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on( - warp_id - ): + for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(warp_id): for mj in range(mma_count_n): p = 0 for k, j in mma_config.b_load_map.on(lane_id): - regs_b[mj, p] = smem_b[ - wk * warp_k + k, wj * warp_n + mj * mma_n + j - ] + regs_b[mj, p] = smem_b[wk * warp_k + k, wj * warp_n + mj * mma_n + j] p += 1 @hidet.script @@ -158,18 +144,13 @@ def store_c(regs_c: f16[4, 8, mma_config.c_elements], c: f16[bs, m_size, n_size] offset_m, offset_n = blockIdx.x * block_m, blockIdx.y * block_n gmem_c = c[blockIdx.z, offset_m:, offset_n:] for k_round in range(warp_count_k): - for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on( - warp_id - ): + for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(warp_id): if wk == k_round: for mi, mj in repeat(mma_count_m, mma_count_n).on(0): p = 0 for i, j in mma_config.c_store_map.on(lane_id): gmem_c.write( - [ - wi * warp_m + mi * mma_m + i, - wj * warp_n + mj * mma_n + j, - ], + [wi * warp_m + mi * mma_m + i, wj * warp_n + mj * mma_n + j], regs_c[mi, mj, p], protected=True, ) @@ -177,9 +158,7 @@ def store_c(regs_c: f16[4, 8, mma_config.c_elements], c: f16[bs, m_size, n_size] @hidet.script def batch_matmul_kernel( - a: f16[bs, m_size, k_size], - b: f16[bs, k_size, n_size], - c: f16[bs, m_size, n_size], + a: f16[bs, m_size, k_size], b: f16[bs, k_size, n_size], c: f16[bs, m_size, n_size] ): """Batch matrix multiplication kernel.""" attrs.cuda.grid_dim = ( diff --git a/gallery/how-to-guides/add-operator-resolve-rule.py b/gallery/developer-guides/add-operator-resolve-rule.py similarity index 100% rename from gallery/how-to-guides/add-operator-resolve-rule.py rename to gallery/developer-guides/add-operator-resolve-rule.py diff --git a/gallery/how-to-guides/add-subgraph-rewrite-rule.py b/gallery/developer-guides/add-subgraph-rewrite-rule.py similarity index 95% rename from gallery/how-to-guides/add-subgraph-rewrite-rule.py rename to gallery/developer-guides/add-subgraph-rewrite-rule.py index f78f7cdff..29098ecbc 100644 --- a/gallery/how-to-guides/add-subgraph-rewrite-rule.py +++ b/gallery/developer-guides/add-subgraph-rewrite-rule.py @@ -16,7 +16,8 @@ .. seealso:: :class: margin - TASO :cite:`taso` systematically studies the sub-graph rewrite optimization for deep learning workloads. 
+ `TASO `_ systematically studies the sub-graph rewrite optimization + for deep learning workloads. After the rewrite, the graph becomes more efficient as we only need to run a single kernel and the `fused` matrix multiplication usually exposes more parallelism to utilize the underlying hardware. We can also fuse multiple @@ -133,10 +134,7 @@ def target(self, matched: MatchDict) -> Optional[List[Tensor]]: # %% # We can check that the rewrite rule has been registered: -from hidet.graph.transforms import ( - registered_rewrite_rules, - clear_registered_rewrite_rules, -) +from hidet.graph.transforms import registered_rewrite_rules, clear_registered_rewrite_rules print('Registered rewrite rules:') for rule in registered_rewrite_rules(): @@ -150,9 +148,7 @@ def target(self, matched: MatchDict) -> Optional[List[Tensor]]: # last line. In this tutorial, to prevent the default rewrite rules from being applied, we first clear the registered # rewrite rules and then register the rewrite rule we just defined: clear_registered_rewrite_rules() -register_rewrite_rule( - FuseTwoMatmulRewriteRule() -) # a second way to register the rewrite rule +register_rewrite_rule(FuseTwoMatmulRewriteRule()) # a second way to register the rewrite rule # %% # The rewrite process is done in a graph optimization pass called `subgraph_rewrite_pass`. @@ -176,8 +172,3 @@ def target(self, matched: MatchDict) -> Optional[List[Tensor]]: # ------- # In this tutorial, we have learned how to define and register a sub-graph rewrite rule. It is an important # component of the graph optimization framework. Hidet uses it to implement some horizontal fusion rules. - -# %% -# References -# ---------- -# .. bibliography:: diff --git a/gallery/developer-guides/hidet-script-dynamic-kernel.py b/gallery/developer-guides/hidet-script-dynamic-kernel.py index c70fac209..ee895f110 100644 --- a/gallery/developer-guides/hidet-script-dynamic-kernel.py +++ b/gallery/developer-guides/hidet-script-dynamic-kernel.py @@ -86,28 +86,18 @@ def matmul_kernel( for i, k in auto_map(block_m_size, block_k_size, workers=num_threads).on( threadIdx.x ): - global_i, global_k = ( - i + blockIdx.x * block_m_size, - k + k_tile * block_k_size, - ) + global_i, global_k = (i + blockIdx.x * block_m_size, k + k_tile * block_k_size) smem_a[i, k] = ( - a[global_i, global_k] - if global_i < m_size and global_k < k_size - else 0.0 + a[global_i, global_k] if global_i < m_size and global_k < k_size else 0.0 ) # load smem_b [block_k_size, block_n_size] from global memory for k, j in auto_map(block_k_size, block_n_size, workers=num_threads).on( threadIdx.x ): - global_k, global_j = ( - k + k_tile * block_k_size, - j + blockIdx.y * block_n_size, - ) + global_k, global_j = (k + k_tile * block_k_size, j + blockIdx.y * block_n_size) smem_b[k, j] = ( - b[global_k, global_j] - if global_k < k_size and global_j < n_size - else 0.0 + b[global_k, global_j] if global_k < k_size and global_j < n_size else 0.0 ) # synchronize all threads in the block @@ -142,15 +132,10 @@ def main(): c = hidet.zeros([m, n]).cuda() func(a, b, c, m, n, k) numpy.testing.assert_allclose( - actual=c.cpu().numpy(), - desired=a.cpu().numpy() @ b.cpu().numpy(), - rtol=1e-4, - atol=1e-4, + actual=c.cpu().numpy(), desired=a.cpu().numpy() @ b.cpu().numpy(), rtol=1e-4, atol=1e-4 ) - hidet_latency = hidet.utils.benchmark_func( - lambda: func(a, b, c, m, n, k), repeat=50 - ) + hidet_latency = hidet.utils.benchmark_func(lambda: func(a, b, c, m, n, k), repeat=50) print(f'{m}x{k}x{n}: hidet takes {hidet_latency:.2f} ms') diff 
--git a/gallery/getting-started/quick-start.py b/gallery/getting-started/quick-start.py index 64c4ec5c2..d0224c6bf 100644 --- a/gallery/getting-started/quick-start.py +++ b/gallery/getting-started/quick-start.py @@ -12,10 +12,9 @@ # .. note:: # :class: margin # -# Torch dynamo is a feature introduced in PyTorch 2.0, which has not been officially released yet. Please install the -# nightly build of PyTorch to use this feature. +# ``torch.compile(...)`` requires PyTorch 2.0+. # -# The easiest way to use Hidet is to use the :func:`torch.compile` function with 'hidet' as the backend, such as +# The easiest way to use Hidet is to use the :func:`torch.compile` function with ``hidet`` as the backend, such as # # .. code-block:: python # @@ -27,7 +26,7 @@ # :class: margin # # Because tf32 is enabled by default for torch's cudnn backend, the torch's precision is slightly low. -# You could disable the tf32 via ``torch.backends.cudnn.allow_tf32 = False``. See also `PyTorch TF32`_. +# You could disable the tf32 (See also `PyTorch TF32`_). # .. _PyTorch TF32: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices import hidet @@ -35,11 +34,12 @@ # take resnet18 as an example x = torch.randn(1, 3, 224, 224).cuda() -model = torch.hub.load( - 'pytorch/vision:v0.9.0', 'resnet18', pretrained=True, verbose=False -) +model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True, verbose=False) model = model.cuda().eval() +# uncomment the following line to enable kernel tuning +# hidet.torch.dynamo_config.search_space(2) + # optimize the model with 'hidet' backend model_opt = torch.compile(model, backend='hidet') @@ -63,6 +63,27 @@ torch.cuda.synchronize() print('{:>10}: {:.3f} ms'.format(name, start_event.elapsed_time(end_event) / 100.0)) +# %% +# One operator can have multiple equivalent implementations (i.e., kernel programs) with different performance. We +# usually need to try different implementations for each concrete input shape to find the best one for the specific +# input shape. This process is called `kernel tuning`. To enable kernel tuning, we can use the following config in +# hidet: +# +# .. code-block:: python +# +# # 0 - no tuning, default kernel will be used +# # 1 - tuning in a small search space +# # 2 - tuning in a large search space, will take longer time and achieves better performance +# hidet.torch.dynamo_config.search_space(2) +# +# When kernel tuning is enabled, hidet can achieve the following performance on NVIDIA RTX 4090: +# +# .. code-block:: text +# +# eager: 1.176 ms +# hidet: 0.286 ms +# + # %% # Hidet provides some configurations to control the optimization of hidet backend. such as @@ -205,5 +226,6 @@ def linear_bias(x, b, c): # %% # Next Step # --------- -# It is time to learn how to use hidet in your project. A good start is to :ref:`Run ONNX Model with Hidet`. +# It is time to learn how to use hidet in your project. A good start is to :ref:`Optimize PyTorch Model` and +# :ref:`Optimize ONNX Model` with Hidet. # diff --git a/gallery/hidet-script/0-hello-world.py b/gallery/hidet-script/0-hello-world.py new file mode 100644 index 000000000..88977c1d7 --- /dev/null +++ b/gallery/hidet-script/0-hello-world.py @@ -0,0 +1,67 @@ +""" +Hello World! +============ + +In this example, we will show you how to use hidet to write a simple "Hello World" program. + +""" +# %% +# Hidet is a deep learning compiler implemented in python. Let's import it first. 
+import hidet + +# %% +# Hidet caches all its generated source code and binary in its cache directory. We can set the cache directory +# to a local directory ``./outs/cache`` so that you can check the generated code and binary. +hidet.option.cache_dir('./outs/cache') + +# %% +# The ``hidet.lang`` submodule implements the Hidet Script domain specific language. +# In this example, we will use ``attrs`` variable and ``printf`` function from ``hidet.lang``. +from hidet.lang import attrs, printf + +# %% +# A **script module** is a compilation unit that contains a list of functions defined in it. Inside a script module, +# we can use ``hidet.script`` to define a hidet script function. The following example defines a function named +# ``launch`` that prints a message to the standard output. + +with hidet.script_module() as script_module: + + # we use `hidet.script` to decorate a python function to define a hidet script function. + @hidet.script + def launch(): + # we use `hidet.lang.attrs` to set the attributes of the function. + # the following line specify this hidet script function is a public function. + attrs.func_kind = 'public' + + # print a message to the standard output. + printf("Hello World!\n") + + +# %% +# With the script module defined, we can build the script module with ``build()`` method. The returned ``module`` is +# an instance of ``hidet.runtime.CompiledModule``, which contains the compiled binary. +module = script_module.build() + +# %% +# We can directly call the compiled module, in this case the 'launch' function would be invoked. +# +# .. note:: +# :class: margin +# +# The printed message has not been captured by our documentation generation tool (i.e., sphinx). +# If you run the script by yourself, you will see the message printed out in your console. +module() + +# %% +# We can also explicitly specify the function to be invoked using ``module['func_name'](args)``. +module['launch']() + +# %% +# you can access the source code of the compiled module using ``module.source()``. +# +# .. note:: +# :class: margin +# +# The function in the source code has a prefix ``hidet_``, which is used to avoid name conflict with standard +# library functions. +print(module.source(color=True)) diff --git a/gallery/hidet-script/1-scalar-addition.py b/gallery/hidet-script/1-scalar-addition.py new file mode 100644 index 000000000..71fcb660f --- /dev/null +++ b/gallery/hidet-script/1-scalar-addition.py @@ -0,0 +1,42 @@ +""" +Scalar Addition +=============== +""" +# %% +# In this example, we will show you how to write a program that adds two float32 numbers. + +# %% +# We first import ``hidet`` and ``hidet.lang`` module, as well as set the cache directory. +import hidet +from hidet.lang import attrs + +hidet.option.cache_dir('./outs/cache') + +# %% +# There are a bunch of data types we can use in Hidet Script, and we can access them in ``hidet.lang.types`` module. +# Each scalar data type has both a full name and a short name. For example, the short name of ``float32`` is +# ``f32``. They are equivalent and can be used interchangeably. +from hidet.lang.types import f32 + + +# %% +# In the script function, we defined two parameters ``a`` and ``b`` with data type ``f32``. The return value of the +# function is also ``f32``. In hidet script, it is **required** to annotate the data type of each parameter. If the +# return type is not annotated, it will be treated as ``void`` data type. 
+with hidet.script_module() as script_module: + + @hidet.script + # In the following example, the data type of a and b is the 32-bit floating point number (f32), + # and the function returns a f32 number. + def launch(a: f32, b: f32) -> f32: + attrs.func_kind = 'public' + + return a + b + + +module = script_module.build() + +# %% +# We can invoke the compiled module with two float32 numbers as arguments, and it will return the sum of the two +# numbers. +print(module(3.0, 4.0)) diff --git a/gallery/hidet-script/2-vector-addition.py b/gallery/hidet-script/2-vector-addition.py new file mode 100644 index 000000000..2e9d8cdd8 --- /dev/null +++ b/gallery/hidet-script/2-vector-addition.py @@ -0,0 +1,43 @@ +""" +Vector Addition +=============== +""" +# %% +# In this example, we will show you how to write a program that adds two float32 vectors in hidet script. +import hidet +from hidet.lang import attrs +from hidet.lang.types import f32 + +hidet.option.cache_dir('./outs/cache') + +# %% +# In the script function, we annotate the data type of parameter ``a``, ``b``, and ``c`` as ``f32[3]``, which means +# a 3-element vector of 32-bit floating point numbers. In general, we can use ``dtype[shape]`` to define a tensor with +# given shape and data type. For example, ``f32[3, 4]`` is a 3x4 float32 matrix, and ``int32[3, 4, 5]`` is a 3x4x5 int32 +# tensor. +# +# We can use ``for i in range(extent)`` to iterate over a range, where ``extent`` is the extent of the loop. +with hidet.script_module() as script_module: + + @hidet.script + def launch(a: f32[3], b: f32[3], c: f32[3]): + attrs.func_kind = 'public' + + for i in range(3): + c[i] = a[i] + b[i] + + +module = script_module.build() + +# %% +# Create the input and output tensors (on cpu, with f32 data type by default): +a = hidet.randn([3]) +b = hidet.randn([3]) +c = hidet.empty([3]) + +# %% +# Call the compiled module with the input and output tensors +module(a, b, c) +print(a) +print(b) +print(c) diff --git a/gallery/hidet-script/3-kernel-functions.py b/gallery/hidet-script/3-kernel-functions.py new file mode 100644 index 000000000..2e31103f2 --- /dev/null +++ b/gallery/hidet-script/3-kernel-functions.py @@ -0,0 +1,117 @@ +""" +Kernel Functions +================ +""" +# %% +# Besides the ``public`` function, there are other function kinds in hidet script. Currently, we support the following +# function kinds: +# +# - ``public``: a public function. The public functions in a script module will be exposed to the outside and can be +# invoked by the outside (in our case, we can call them in python). +# - ``cpu_kernel``: a kernel function on cpu. +# - ``cpu_internal``: an internal function on cpu. +# - ``cuda_kernel``: a kernel function on cuda. +# - ``cuda_internal``: an internal function on cuda. +# +# .. tip:: +# :class: margin +# +# The ``cuda_kernel`` and ``cuda_internal`` correspond to the ``__global__`` and ``__device__`` functions in CUDA. +# +# Usually, we use the ``cpu_kernel`` and ``cuda_kernel`` to define the kernel functions. The ``cpu_internal`` and +# ``cuda_internal`` are used to define the internal functions that are only used by the kernel functions. +# +# When there is only one kernel function in a script module and there is no function named ``launch``, a default +# ``launch`` function will be generated to launch the kernel function. 
+# + +# %% +# CPU kernel function +# ------------------- +import hidet +from hidet.lang import attrs +from hidet.lang.types import f32 + +hidet.option.cache_dir('./outs/cache') + +with hidet.script_module() as script_module: + + @hidet.script + def matmul(a: f32[16, 16], b: f32[16, 16], c: f32[16, 16]): + # specify the function kind as 'cpu_kernel' + attrs.func_kind = 'cpu_kernel' + + for i in range(16): + for j in range(16): + c[i, j] = 0.0 + for k in range(16): + c[i, j] += a[i, k] * b[k, j] + + +module = script_module.build() + +a = hidet.randn([16, 16]) +b = hidet.randn([16, 16]) +c = hidet.empty([16, 16]) + +module(a, b, c) + +# %% +# We can check the generated source code to see that the ``launch`` function is generated automatically. +print(module.source()) + + +# %% +# CUDA kernel function +# -------------------- +# We can also define a kernel function on CUDA. The following example defines a kernel function on cuda. +# +# We can access cuda primitive variables and functions in the ``hidet.lang.cuda`` module. +from hidet.lang.cuda import blockIdx, threadIdx, blockDim + +# workload size +m_size = 1024 +n_size = 1024 +k_size = 1024 + +with hidet.script_module() as script_module: + + @hidet.script + def matmul(a: f32[m_size, k_size], b: f32[k_size, n_size], c: f32[m_size, n_size]): + # specify the function kind as 'cuda_kernel' + attrs.func_kind = 'cuda_kernel' + + # specify the grid dimension and block dimension + attrs.cuda.grid_dim = (m_size + 15) // 16, (n_size + 15) // 16 + attrs.cuda.block_dim = 16, 16 + + # the coordinate of the c matrix that this thread is responsible for + i = blockIdx.x * blockDim.x + threadIdx.x + j = blockIdx.y * blockDim.y + threadIdx.y + + if i < m_size and j < n_size: + c[i, j] = 0.0 + for k in range(k_size): + c[i, j] += a[i, k] * b[k, j] + + +module = script_module.build() + +a = hidet.randn([m_size, k_size], device='cuda') +b = hidet.randn([k_size, n_size], device='cuda') +c = hidet.empty([m_size, n_size], device='cuda') + +module(a, b, c) + +# compare the result with torch.matmul +hidet.utils.assert_close(c, a.torch() @ b.torch(), atol=1e-4, rtol=1e-4) + +# %% +# We can check the generated source code: +# +# .. tip:: +# :class: margin +# +# You can find that there is no boundary checking in the kernel function. This is because hidet infers the value +# range for each index variable and finds that the if condition is always true, so it simplifies the if-statement. +print(module.source()) diff --git a/gallery/hidet-script/4-naive-matmul.py b/gallery/hidet-script/4-naive-matmul.py new file mode 100644 index 000000000..5da10e9e5 --- /dev/null +++ b/gallery/hidet-script/4-naive-matmul.py @@ -0,0 +1,66 @@ +""" +Naive Matrix Multiplication +=========================== +""" +# %% +# In this example, we will show you how to write a program that performs matrix multiplication on GPU that supports +# arbitrary input size. 
+import torch +import hidet +from hidet.lang import attrs +from hidet.lang.types import f32, i32, tensor_pointer +from hidet.lang.cuda import threadIdx, blockIdx + +hidet.option.cache_dir('./outs/cache') + +with hidet.script_module() as script_module: + + @hidet.script + def matmul_kernel(a_ptr: ~f32, b_ptr: ~f32, c_ptr: ~f32, m_size: i32, n_size: i32, k_size: i32): + attrs.func_kind = 'cuda_kernel' + attrs.cuda.block_dim = 16, 16 + attrs.cuda.grid_dim = (m_size + 15) // 16, (n_size + 15) // 16 + + # define three tensor pointers that hold the shape and dtype information + a = tensor_pointer(dtype=f32, shape=[m_size, k_size], init=a_ptr) + b = tensor_pointer(dtype=f32, shape=[k_size, n_size], init=b_ptr) + c = tensor_pointer(dtype=f32, shape=[m_size, n_size], init=c_ptr) + + i = blockIdx.x * 16 + threadIdx.x + j = blockIdx.y * 16 + threadIdx.y + + if i < m_size and j < n_size: + c[i, j] = 0.0 + for k in range(k_size): + c[i, j] += a[i, k] * b[k, j] + + +module = script_module.build() + + +# %% +# Hidet compiled module can be called directly with pytorch tensors. + + +def matmul(a: torch.Tensor, b: torch.Tensor): + m_size, n_size, k_size = a.shape[0], b.shape[1], a.shape[1] + c = torch.empty([m_size, n_size], device='cuda') + module(a, b, c, m_size, n_size, k_size) + return c + + +# %% +# Run the compiled kernels with different input sizes and check the correctness of the result. +for m_size, n_size, k_size in [(234, 345, 567), (123, 456, 789)]: + a = torch.randn(m_size, k_size, device='cuda') + b = torch.randn(k_size, n_size, device='cuda') + + c1 = matmul(a, b) + c2 = torch.matmul(a, b) + + # check the correctness of the result + torch.testing.assert_close(c1, c2, atol=1e-4, rtol=1e-4) + + +# %% +print(module.source()) diff --git a/gallery/hidet-script/5-efficient-matmul.py b/gallery/hidet-script/5-efficient-matmul.py new file mode 100644 index 000000000..19c9c1cbf --- /dev/null +++ b/gallery/hidet-script/5-efficient-matmul.py @@ -0,0 +1,164 @@ +""" +More Efficient Matrix Multiplication +==================================== + +In this example, we show you how to write a more efficient matrix multiplication kernel on NVIDIA GPU that uses shared +memory. For simplicity, we omitted some optimizations like software pipelining (see our `paper`_ for more details). + +.. _paper: https://dl.acm.org/doi/10.1145/3575693.3575702 + + +Feel free to skip this example if you are not familiar with CUDA programming. 
+ +""" +# %% +import torch +import hidet +from hidet.lang import attrs +from hidet.lang import float32, int32 +from hidet.lang import as_tensor_pointer, register_tensor, shared_tensor +from hidet.lang.cuda import threadIdx, blockIdx, syncthreads +from hidet.lang.mapping import spatial, auto_map +from hidet.lang.layout import row_major, local_layout + +# the hyperparameters of the kernel +warps_m, warps_n = 4, 2 # we use 4x2 warps +warp_m, warp_n = 2, 2 # each warp repeats 2x2 times +warp_map_m, warp_map_n = 2, 16 # each warp has 2x16 threads +thread_m, thread_n = 4, 4 # each thread repeats 4x4 times + +# block_size = (64, 256, 8) +block_m_size, block_n_size = ( + warps_m * warp_m * warp_map_m * thread_m, + warps_n * warp_n * warp_map_n * thread_n, +) +block_k_size = 8 +num_warps = warps_m * warps_n # 8 +num_threads = num_warps * 32 # 256 + +with hidet.lang.script_module() as script_module: + + @hidet.script + def relu(x: float32) -> float32: + return x if x > 0.0 else 0.0 + + @hidet.script + def matmul_relu_kernel( + a_ptr: ~float32, + b_ptr: ~float32, + c_ptr: ~float32, + m_size: int32, + n_size: int32, + k_size: int32, + ): + attrs.func_name = 'matmul_kernel' + attrs.cuda.block_dim = num_threads + attrs.cuda.grid_dim = ( + (m_size + block_m_size - 1) // block_m_size, + (n_size + block_n_size - 1) // block_n_size, + ) + + a = as_tensor_pointer(a_ptr, float32, [m_size, k_size]) + b = as_tensor_pointer(b_ptr, float32, [k_size, n_size]) + c = as_tensor_pointer(c_ptr, float32, [m_size, n_size]) + + # define tensors in shared memory + smem_a = shared_tensor(float32, shape=[block_m_size, block_k_size]) + smem_b = shared_tensor(float32, shape=[block_k_size, block_n_size]) + + # define the accumulation tensor in register + regs_c = register_tensor( + dtype=float32, + # shape will be inferred from the layout automatically, + # in this case, the shape is [64, 256] + layout=( + local_layout(warps_m, warps_n) + * row_major(warp_m, warp_n) + * local_layout(warp_map_m, warp_map_n) + * row_major(thread_m, thread_n) + ), + ) + + # initialize the registers + mma_mapping = ( + spatial(warps_m, warps_n) + .repeat(warp_m, warp_n) + .spatial(warp_map_m, warp_map_n) + .repeat(thread_m, thread_n) + ) + for i, j in mma_mapping.on(threadIdx.x): + regs_c[i, j] = 0.0 + + # iterate over the k tiles + num_k_tiles = (k_size + block_k_size - 1) // block_k_size + for k_tile in range(num_k_tiles): + # load smem_a [block_m_size, block_k_size] from global memory + for i, k in auto_map(block_m_size, block_k_size, workers=num_threads).on(threadIdx.x): + global_i, global_k = (i + blockIdx.x * block_m_size, k + k_tile * block_k_size) + smem_a[i, k] = ( + a[global_i, global_k] if global_i < m_size and global_k < k_size else 0.0 + ) + + # load smem_b [block_k_size, block_n_size] from global memory + for k, j in auto_map(block_k_size, block_n_size, workers=num_threads).on(threadIdx.x): + global_k, global_j = (k + k_tile * block_k_size, j + blockIdx.y * block_n_size) + smem_b[k, j] = ( + b[global_k, global_j] if global_k < k_size and global_j < n_size else 0.0 + ) + + # synchronize all threads in the block + syncthreads() + + # simt matrix multiply accumulate (mma): regs_c = regs_c + smem_a @ smem_b + for i, j in mma_mapping.on(threadIdx.x): + for k in range(block_k_size): + regs_c[i, j] += smem_a[i, k] * smem_b[k, j] + + # synchronize all threads in the block + syncthreads() + + # store regs_c back to global memory + for i, j in mma_mapping.on(threadIdx.x): + global_i = i + blockIdx.x * block_m_size + global_j = j + blockIdx.y * 
block_n_size
+                if global_i < m_size and global_j < n_size:
+                    c[global_i, global_j] = relu(regs_c[i, j])
+
+
+module = script_module.build()
+
+
+def hidet_matmul_relu(a: torch.Tensor, b: torch.Tensor):
+    m_size, n_size, k_size = a.shape[0], b.shape[1], a.shape[1]
+    c = torch.empty([m_size, n_size], device='cuda')
+    module(a, b, c, m_size, n_size, k_size)
+    return c
+
+
+def torch_matmul_relu(a: torch.Tensor, b: torch.Tensor):
+    return torch.matmul(a, b).relu()
+
+
+# %%
+# Run the program with different input sizes. This implementation achieves about 30% of the performance of the
+# cuBLAS kernels. For more efficient implementations, please refer to the `ones`_ in the hidet package.
+#
+# .. _ones: https://github.com/hidet-org/hidet/tree/main/python/hidet/graph/ops/matmul
+
+for m, n, k in [(1024, 1024, 1024), (256, 256, 256), (32, 32, 32)]:
+    a = torch.randn(m, k, dtype=torch.float32, device='cuda')
+    b = torch.randn(k, n, dtype=torch.float32, device='cuda')
+
+    c1 = hidet_matmul_relu(a, b)
+    c2 = torch_matmul_relu(a, b)
+
+    torch.testing.assert_close(c1, c2, atol=1e-4, rtol=1e-4)
+
+    # measure the latency of both implementations (in milliseconds)
+    torch_latency = hidet.utils.benchmark_func(lambda: torch_matmul_relu(a, b), repeat=50)
+    hidet_latency = hidet.utils.benchmark_func(lambda: hidet_matmul_relu(a, b), repeat=50)
+    print(f'{m}x{k}x{n}:')
+    print(' torch: {:.3f} ms'.format(torch_latency))
+    print(' hidet: {:.3f} ms'.format(hidet_latency))
+
+# %%
+# Get the source code:
+print(module.source())
diff --git a/gallery/hidet-script/README.rst b/gallery/hidet-script/README.rst
new file mode 100644
index 000000000..66a235227
--- /dev/null
+++ b/gallery/hidet-script/README.rst
@@ -0,0 +1,2 @@
+Index
+=====
\ No newline at end of file
diff --git a/gallery/how-to-guides/visualize-flow-graph.py b/gallery/how-to-guides/visualize-flow-graph.py
index 34ed5e108..87caac5f5 100644
--- a/gallery/how-to-guides/visualize-flow-graph.py
+++ b/gallery/how-to-guides/visualize-flow-graph.py
@@ -34,9 +34,7 @@ def __init__(self, hidden_size=768, num_attention_heads=12):
 
     def transpose_for_scores(self, x: Tensor) -> Tensor:
         batch_size, seq_length, hidden_size = x.shape
-        x = x.reshape(
-            [batch_size, seq_length, self.num_attention_heads, self.attention_head_size]
-        )
+        x = x.reshape([batch_size, seq_length, self.num_attention_heads, self.attention_head_size])
         x = x.rearrange([[0, 2], [1], [3]])
         return x  # [batch_size * num_attention_heads, seq_length, attention_head_size]
diff --git a/gallery/tutorials/run-onnx-model.py b/gallery/tutorials/optimize-onnx-model.py
similarity index 92%
rename from gallery/tutorials/run-onnx-model.py
rename to gallery/tutorials/optimize-onnx-model.py
index cc441f304..6487ebfff 100644
--- a/gallery/tutorials/run-onnx-model.py
+++ b/gallery/tutorials/optimize-onnx-model.py
@@ -1,6 +1,6 @@
 """
 .. currentmodule:: hidet
-.. _Run ONNX Model with Hidet:
+.. _Optimize ONNX Model:
 
 Optimize ONNX Model
 ===================
@@ -23,9 +23,7 @@ onnx_path = './resnet50.onnx'
 
 # load pretrained resnet50 and create a random input
-torch_model = torch.hub.load(
-    'pytorch/vision:v0.9.0', 'resnet50', pretrained=True, verbose=False
-)
+torch_model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True, verbose=False)
 torch_model = torch_model.cuda().eval()
 torch_data = torch.randn([1, 3, 224, 224]).cuda()
@@ -119,10 +117,7 @@ def bench_hidet_graph(graph: hidet.FlowGraph):
     cuda_graph = graph.cuda_graph()
     (output,) = cuda_graph.run([data])
     np.testing.assert_allclose(
-        actual=output.cpu().numpy(),
-        desired=torch_output.cpu().numpy(),
-        rtol=1e-2,
-        atol=1e-2,
+        actual=output.cpu().numpy(), desired=torch_output.cpu().numpy(), rtol=1e-2, atol=1e-2
     )
     print(' Hidet: {:.3f} ms'.format(benchmark_func(lambda: cuda_graph.run())))
@@ -146,6 +141,16 @@ def bench_hidet_graph(graph: hidet.FlowGraph):
 
 bench_hidet_graph(graph_opt)
 
+# %%
+# When we search in space 2, we get the following numbers on an RTX 4090:
+#
+# .. code-block:: text
+#
+#    PyTorch: 1.806 ms (eager mode)
+#    Hidet: 3.477 ms (no optimization)
+#    Hidet: 0.841 ms (optimization and search space 2)
+#
+
 # %%
 # Summary
 # -------
diff --git a/gallery/tutorials/optimize-pytorch-model.py b/gallery/tutorials/optimize-pytorch-model.py
index 60306fb7c..ce1df17d1 100644
--- a/gallery/tutorials/optimize-pytorch-model.py
+++ b/gallery/tutorials/optimize-pytorch-model.py
@@ -1,4 +1,6 @@
 """
+.. _Optimize PyTorch Model:
+
 Optimize PyTorch Model
 ======================
 
@@ -69,9 +71,7 @@ import hidet
 
 x = torch.randn(1, 3, 224, 224).cuda()
-model = torch.hub.load(
-    'pytorch/vision:v0.9.0', 'resnet18', pretrained=True, verbose=False
-)
+model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True, verbose=False)
 model = model.cuda().eval()
 
 with torch.no_grad():
diff --git a/python/hidet/graph/transforms/base.py b/python/hidet/graph/transforms/base.py
index 364b5b85a..5f126e99a 100644
--- a/python/hidet/graph/transforms/base.py
+++ b/python/hidet/graph/transforms/base.py
@@ -300,6 +300,7 @@ def profile_pass_instrument(self, log_file: Optional[str] = None, print_stdout:
     def reduce_cuda_compile_mem(self, enable: Optional[bool] = None):
         """
         Reduce CUDA memory used during compilation by using vcuda tensors, might incur compile time cost
+
         Parameters
         ----------
         enable: Optional[bool]
diff --git a/scripts/lint/format.sh b/scripts/lint/format.sh
index 029c092a8..b54fb58ef 100644
--- a/scripts/lint/format.sh
+++ b/scripts/lint/format.sh
@@ -6,4 +6,4 @@ cd $SCRIPT_DIR
 
 # run black formatter
 python -m black --skip-string-normalization --skip-magic-trailing-comma --line-length 120 ../../python/hidet ../../tests
-python -m black --skip-string-normalization --skip-magic-trailing-comma --line-length 90 ../../gallery
+python -m black --skip-string-normalization --skip-magic-trailing-comma --line-length 100 ../../gallery
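
Note for reviewers: below is a minimal usage sketch (not part of the diff above) of the reduce_cuda_compile_mem option whose docstring is touched in python/hidet/graph/transforms/base.py. It assumes the method belongs to hidet.graph.PassContext and that a FlowGraph is traced with hidet.symbol / hidet.trace_from as in the tutorials; treat these entry points as assumptions rather than confirmed API.

import hidet

# trace a small toy computation into a FlowGraph (hypothetical example, for illustration only)
x = hidet.symbol([16, 16], dtype='float32', device='cuda')
y = hidet.ops.relu(hidet.ops.matmul(x, x))
graph = hidet.trace_from(y, inputs=[x])

# enable the option documented above: use vcuda tensors during compilation to reduce
# CUDA memory usage, at the cost of extra compile time
with hidet.graph.PassContext() as ctx:
    ctx.reduce_cuda_compile_mem(True)
    graph_opt = hidet.graph.optimize(graph)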